From 302aa809ead97ba0108c845bd851b3c871ea2909 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Fri, 12 Sep 2025 20:11:58 +0000 Subject: [PATCH] Merge commit 'b0ee317d83b77741022997265d4125697e7f7804' into develop --- Dockerfile.aiter | 6 +- Jenkinsfile | 46 +- example/35_splitK_gemm/CMakeLists.txt | 13 + example/35_splitK_gemm/common.hpp | 82 +++ .../gemm_wmma_splitk_reduce_bf16.cpp | 59 ++ .../gemm_wmma_splitk_reduce_bf16A_i8B.cpp | 59 ++ .../gemm_wmma_splitk_reduce_multi_d_bf16.cpp | 59 ++ .../gemm_wmma_splitk_reduce_multi_d_fp16.cpp | 59 ++ ...run_gemm_splitk_reduce_multi_d_example.inc | 82 --- .../run_gemm_wmma_splitk_reduce_example.inc | 191 ++++++ ...emm_wmma_splitk_reduce_multi_d_example.inc | 214 +++++++ example/ck_tile/01_fmha/CMakeLists.txt | 11 +- example/ck_tile/03_gemm/run_gemm_example.inc | 8 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 8 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 8 +- include/ck/host_utility/device_prop.hpp | 5 + .../impl/device_gemm_wmma_cshuffle_v3r1.hpp | 562 ++++++++++++++++++ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 12 + ...p_gemm_attribute_wmma_impl_base_traits.hpp | 4 +- .../gpu/gemm_universal_reduce.hpp | 72 ++- .../gpu/gemm_universal_reduce/CMakeLists.txt | 10 +- ...wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 72 +++ ...16_bf16_mk_kn_mn_comp_default_instance.cpp | 58 ++ ...m_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp | 73 +++ ...i8_bf16_mk_kn_mn_comp_default_instance.cpp | 59 ++ ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 72 +++ ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 57 ++ .../profile_gemm_universal_reduce_impl.hpp | 18 +- profiler/src/CMakeLists.txt | 4 +- test/CMakeLists.txt | 1 + .../add_rmsnorm2d_rdquant/CMakeLists.txt | 2 +- test/ck_tile/batched_gemm/CMakeLists.txt | 3 +- .../batched_gemm/test_batched_gemm_util.hpp | 44 +- test/ck_tile/batched_transpose/CMakeLists.txt | 3 +- test/ck_tile/container/CMakeLists.txt | 2 +- test/ck_tile/data_type/CMakeLists.txt | 2 +- test/ck_tile/elementwise/CMakeLists.txt | 2 +- .../elementwise/test_elementwise_1d.cpp | 2 +- test/ck_tile/fmha/test_fmha_fwd.inc | 2 +- test/ck_tile/gemm/CMakeLists.txt | 70 ++- .../test_gemm_pipeline_basic_run_test.inc | 65 +- .../gemm/test_gemm_pipeline_smoke_util.hpp | 21 + .../test_gemm_pipeline_universal_run_test.inc | 8 + test/ck_tile/gemm_multi_d/CMakeLists.txt | 3 +- .../gemm_multi_d/test_gemm_multi_d_util.hpp | 53 +- .../gemm_weight_preshuffle/CMakeLists.txt | 2 +- .../test_gemm_pipeline_kernel_types.hpp | 6 +- .../test_gemm_pipeline_util.hpp | 86 ++- test/ck_tile/grouped_gemm/CMakeLists.txt | 2 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 33 +- test/ck_tile/image_to_column/CMakeLists.txt | 3 +- test/ck_tile/layernorm2d/CMakeLists.txt | 2 +- test/ck_tile/moe_smoothquant/CMakeLists.txt | 3 +- test/ck_tile/moe_sorting/CMakeLists.txt | 4 +- test/ck_tile/permute/CMakeLists.txt | 3 +- test/ck_tile/permute/test_permute_util.hpp | 4 + test/ck_tile/reduce/CMakeLists.txt | 2 +- test/ck_tile/reduce/test_reduce2d.cpp | 2 +- test/ck_tile/rmsnorm2d/CMakeLists.txt | 2 +- test/ck_tile/smoothquant/CMakeLists.txt | 3 +- test/ck_tile/topk_softmax/CMakeLists.txt | 3 +- test/gemm_universal_reduce/CMakeLists.txt | 14 + ...st_gemm_universal_reduce_bf16A_i8_wmma.cpp | 31 + .../test_gemm_universal_reduce_bf16_wmma.cpp | 31 + .../test_gemm_universal_reduce_fp16_wmma.cpp | 31 + 65 files changed, 2301 insertions(+), 232 deletions(-) create mode 100644 example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp create mode 100644 example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp create mode 100644 example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp create mode 100644 example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp create mode 100644 example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc create mode 100644 example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 test/gemm_universal_reduce/CMakeLists.txt create mode 100644 test/gemm_universal_reduce/test_gemm_universal_reduce_bf16A_i8_wmma.cpp create mode 100644 test/gemm_universal_reduce/test_gemm_universal_reduce_bf16_wmma.cpp create mode 100644 test/gemm_universal_reduce/test_gemm_universal_reduce_fp16_wmma.cpp diff --git a/Dockerfile.aiter b/Dockerfile.aiter index 245e39fb75..b61c1e41a5 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -1,10 +1,8 @@ -ARG BASE_DOCKER="rocm/pytorch:latest" +ARG BASE_DOCKER="rocm/composable_kernel-private:ck_aiter_base" FROM $BASE_DOCKER ARG AITER_BRANCH="main" ARG CK_AITER_BRANCH="develop" -RUN groupadd -g 109 render && \ - usermod -u 1001 jenkins && \ - groupmod -g 1001 jenkins && \ +RUN groupadd irc && \ pip install pandas zmq einops && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ diff --git a/Jenkinsfile b/Jenkinsfile index 87654328b6..9d1af7c5d9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -149,7 +149,7 @@ def getDockerImage(Map conf=[:]){ image = conf.get("docker_name", "") echo "Using legacy docker: ${image}" } - else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){ + else if ( (params.BUILD_GFX950 || params.RUN_CK_TILE_FMHA_TESTS) && conf.get("docker_name", "") != "" ){ image = conf.get("docker_name", "") echo "Using special docker: ${image}" } @@ -186,11 +186,11 @@ def buildDocker(install_prefix){ dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . " } else if(params.RUN_AITER_TESTS){ - image_name = "rocm/composable_kernel:ck_aiter" + image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . " } else if(params.RUN_PYTORCH_TESTS){ - image_name = "rocm/composable_kernel:ck_pytorch" + image_name = "${env.CK_DOCKERHUB}:ck_pytorch" dockerArgs = dockerArgs + " --no-cache -f Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . " } else{ @@ -716,7 +716,7 @@ def process_results(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm //use older image that has user jenkins - def image = "rocm/composable_kernel:ck_ub22.04_rocm6.3" + def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" // Jenkins is complaining about the render group @@ -827,7 +827,7 @@ def run_aiter_tests(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm //use the latest pytorch image - def image = "rocm/composable_kernel:ck_aiter" + def image = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins" def variant = env.STAGE_NAME def retimage @@ -885,7 +885,7 @@ def run_pytorch_tests(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm //use the latest pytorch-nightly image - def image = "rocm/composable_kernel:ck_pytorch" + def image = "${env.CK_DOCKERHUB}:ck_pytorch" def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins" def variant = env.STAGE_NAME def retimage @@ -1207,6 +1207,18 @@ pipeline { cleanWs() } } + stage("Run AITER Tests on gfx950") + { + when { + beforeAgent true + expression { params.RUN_AITER_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx950")} + steps{ + run_aiter_tests() + cleanWs() + } + } } } stage("Run Grouped Conv Large Case Tests") @@ -1321,7 +1333,7 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ } @@ -1330,6 +1342,26 @@ pipeline { cleanWs() } } + stage("Run CK_TILE_FMHA Tests on gfx950") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx950") } + environment{ + def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0" + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx950 && \ + make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx950 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, docker_name: docker_name, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } } } stage("Run TILE_ENGINE_GEMM Tests") diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 904006ba36..e0476bfaad 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -27,3 +27,16 @@ add_example_executable(example_gemm_xdl_splitk_reduce_multi_d_bf16 gemm_xdl_spli add_example_executable(example_gemm_xdl_splitk_reduce_bf16A_i8B gemm_xdl_splitk_reduce_bf16A_i8B.cpp) add_example_executable(example_gemm_xdl_splitk_reduce_bfp16 gemm_xdl_splitk_reduce_bf16.cpp) + +add_custom_target(example_splitK_gemm_wmma) +add_example_executable(example_gemm_wmma_splitk_reduce_bf16 gemm_wmma_splitk_reduce_bf16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16) + +add_example_executable(example_gemm_wmma_splitk_reduce_bf16A_i8B gemm_wmma_splitk_reduce_bf16A_i8B.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16A_i8B) + +add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_bf16 gemm_wmma_splitk_reduce_multi_d_bf16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_bf16) + +add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_fp16 gemm_wmma_splitk_reduce_multi_d_fp16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_fp16) diff --git a/example/35_splitK_gemm/common.hpp b/example/35_splitK_gemm/common.hpp index 64fadae9e5..325cc37731 100644 --- a/example/35_splitK_gemm/common.hpp +++ b/example/35_splitK_gemm/common.hpp @@ -99,3 +99,85 @@ bool parse_cmd_args(int argc, return true; } + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp new file mode 100644 index 0000000000..b481483d42 --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = ck::bhalf_t; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceWmmaGemmInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_example.inc" + +int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp new file mode 100644 index 0000000000..dcf4a1652d --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = int8_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceWmmaGemmInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_example.inc" + +int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp new file mode 100644 index 0000000000..dab308d148 --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp new file mode 100644 index 0000000000..489816559d --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; +using ReduceDataType = float; +using D0DataType = ck::half_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 256, 64, + 8, 8, + 16, 16, + 4, 4, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); } diff --git a/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc index 9635993d63..0b060841bf 100644 --- a/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc +++ b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc new file mode 100644 index 0000000000..25628ef770 --- /dev/null +++ b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_wmma_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << KBatch << std::endl; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device GEMM + auto device_op = DeviceWmmaGemmInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, // empty D tensors + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, // empty D strides + StrideC, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + // Allocate workspace for split-K reduction if needed + size_t workspace_size = device_op.GetWorkSpaceSize(argument.get()); + DeviceMem workspace_buf(workspace_size); + std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl; + if(workspace_size > 0) + { + argument->p_workspace_ = workspace_buf.GetDeviceBuffer(); + std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl; + } + + if(!device_op.IsSupportedArgument(argument.get())) + { + std::cout << "The runtime argument is not supported!" << std::endl; + std::cout << "Debug info:" << std::endl; + std::cout << " M=" << M << ", N=" << N << ", K=" << K << ", KBatch=" << KBatch + << std::endl; + std::cout << " StrideA=" << StrideA << ", StrideB=" << StrideB << ", StrideC=" << StrideC + << std::endl; + return false; + } + + bool pass = true; + float ave_time = 0; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, false}); + + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E12 / ave_time; + + float gb_per_sec = num_btype / 1.E9 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_wmma_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config); +} diff --git a/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc new file mode 100644 index 0000000000..59996655c6 --- /dev/null +++ b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_wmma_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto StrideD0 = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + + Tensor a_m_k( + f_host_tensor_descriptor(problem_size.M, problem_size.K, problem_size.StrideA, ALayout{})); + Tensor b_k_n( + f_host_tensor_descriptor(problem_size.K, problem_size.N, problem_size.StrideB, BLayout{})); + Tensor d0_m_n( + f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, D0Layout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << KBatch << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + constexpr auto kNum_DTensors = DsDataType::Size(); + const std::array p_ds = {d0_m_n_device_buf.GetDeviceBuffer()}; + const std::array d_strides = {problem_size.StrideC}; + + auto argument = + gemm.MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + p_ds, + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + problem_size.M, + problem_size.N, + problem_size.K, + problem_size.StrideA, + problem_size.StrideB, + d_strides, + problem_size.StrideC, + problem_size.KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument.get())) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + return false; + } + + auto workspace_size = gemm.GetWorkSpaceSize(argument.get()); + DeviceMem workspace_device_buf(workspace_size); + + std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl; + std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl; + + if(workspace_size > 0) + { + argument->p_workspace_ = workspace_device_buf.GetDeviceBuffer(); + } + + if(config.do_verification) + { + using ReferenceGemmInstanceMultiD = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstanceMultiD{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + c_m_n_host_result.ForEach( + [&](auto& self, auto idx) { c_element_op(self(idx), self(idx), d0_m_n(idx)); }); + } + + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << problem_size.KBatch << std::endl; + + float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * problem_size.M * problem_size.N * problem_size.K; + std::size_t num_btype = sizeof(ADataType) * problem_size.M * problem_size.K + + sizeof(BDataType) * problem_size.K * problem_size.N + + sizeof(CDataType) * problem_size.M * problem_size.N + + sizeof(D0DataType) * problem_size.M * problem_size.N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + double rtol = get_rtol(); + double atol = get_atol(); + + return ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", rtol, atol); + } + + return true; +} + +int run_gemm_splitk_multi_d_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config); +} diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index b1e2373657..68db468a7c 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -26,7 +26,7 @@ endforeach() # "fwd" is a must-have api for the fmha_fwd example, add it if not specified if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND FMHA_FWD_ENABLE_APIS "fwd") + list(PREPEND FMHA_FWD_ENABLE_APIS "fwd") endif() file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS @@ -51,6 +51,15 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) +# Reduce building time by disabling instances that are not currently used in the gtests +# TODO: Consider to use a special receipt for testing only, or even two receipts: a small subset of +# instances for quick CI runs and a larger subset for scheduled runs (the tests skip tests when +# there is no corresponding instance for parameters). +if(BUILD_TESTING) + # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) +endif() + # generate a list of kernels, but not actually emit files at config sta execute_process( COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index cc980a75f7..e6875f97d5 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -181,15 +181,15 @@ auto shuffle_b(const ck_tile::HostTensor& t) if(ck_tile::is_gfx12_supported()) { - // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase constexpr int divisor = 2; - constexpr int kABK0PerLane = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, GemmConfig::N_Warp_Tile, k_ / GemmConfig::K_Warp_Tile, - divisor, kABK0PerLane, - GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + divisor, + kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index a8abcee41e..1ae0844032 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -314,15 +314,15 @@ auto shuffle_b(const ck_tile::HostTensor& t) if(ck_tile::is_gfx12_supported()) { - // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase constexpr int divisor = 2; - constexpr int kABK0PerLane = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, GemmConfig::N_Warp_Tile, k_ / GemmConfig::K_Warp_Tile, - divisor, kABK0PerLane, - GemmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + divisor, + kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 63d0a80555..c187f72594 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -45,15 +45,15 @@ auto shuffle_b(const ck_tile::HostTensor& t) if(ck_tile::is_gfx12_supported()) { - // TODO: Please modify it once kABK0PerLane is changed in WmmaTraitsBase constexpr int divisor = 2; - constexpr int kABK0PerLane = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = FlatmmConfig::K_Warp_Tile / divisor / kABK1PerLane; ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, FlatmmConfig::N_Warp_Tile, k_ / FlatmmConfig::K_Warp_Tile, - divisor, kABK0PerLane, - FlatmmConfig::K_Warp_Tile / divisor / kABK0PerLane}); + divisor, + kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 2e949bb1df..6b04b21e4f 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -129,5 +129,10 @@ inline bool is_gfx103_supported() ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; } +inline bool is_wmma_supported() +{ + return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); +} + } // namespace ck #endif diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp new file mode 100644 index 0000000000..3a06ea8451 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1 +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + ReduceDataType, + Tuple<>, + ReduceDataType, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + const ::std::array p_ds_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + const ::std::array stride_ds_, + index_t StrideC_, + index_t KBatch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + ::std::array{}, + reinterpret_cast(p_c_grid_), + M_, + N_, + K_, + StrideA_, + StrideB_, + std::array{}, + StrideC_, + KBatch_, + a_element_op_, + b_element_op_, + PassThrough{}, + true), + p_c_grid(p_c_grid_), + c_element_op(c_element_op_), + p_ds(p_ds_), + StrideDs(stride_ds_) + { + } + + CDataType* p_c_grid; + CElementwiseOperation c_element_op; + const ::std::array p_ds; + ::std::array StrideDs; + }; + + using ReduceAdd = ck::reduce::Add; + using OutElementwiseOperation = CElementwiseOperation; + + static constexpr auto DsVectorLengthSequence = generate_sequence_v2( + [](auto i) { + using DLayout = ::std::__remove_cvref_t>; + if constexpr(is_same::value) + return Number{}; + else + return Number<1>{}; + }, + Number{}); + + using DeviceReduceInstance = DeviceReduceThreadWiseMultiD< + ReduceDataType, // InDataType + DsDataType, // DsDatatype + GemmAccDataType, // AccDataType + CDataType, // OutDataType + 3, // Rank + 1, // NumReduceDim + ReduceAdd, + PassThrough, + OutElementwiseOperation, + 256, // BlockSize_ + CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_ + 1, // KThreadSliceSize_ + 0, // InSrcVectorDim_ + CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_ + CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_ + decltype(DsVectorLengthSequence)>; + + struct Invoker : public BaseInvoker + { + float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + static constexpr index_t NumInDim = 3; + static constexpr index_t NumOutDim = 2; + + ::std::array in_lengths = {arg.KBatch, arg.M, arg.N}; + ::std::array out_lengths = {arg.M, arg.N}; + + ::std::array in_strides; + ::std::array out_strides; + if constexpr(is_same::value) + { + in_strides = {arg.M * arg.N, arg.N, 1}; + out_strides = {arg.N, 1}; + } + else + { + in_strides = {arg.M * arg.N, 1, arg.M}; + out_strides = {1, arg.M}; + } + + ::std::array reduce_dims{0}; + + ::std::array<::std::array, NumDTensor> DsLengths; + ::std::array<::std::array, NumDTensor> DsStrides; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + DsLengths[i] = out_lengths; + + using DLayout = ::std::__remove_cvref_t>; + if constexpr(is_same::value) + { + DsStrides[i] = {arg.StrideDs[i], 1}; + } + else + { + DsStrides[i] = {1, arg.StrideDs[i]}; + } + }); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(in_lengths, + in_strides, + DsLengths, + DsStrides, + out_lengths, + out_strides, + reduce_dims, + arg.p_workspace_, + arg.p_ds, + arg.p_c_grid, + PassThrough{}, + OutElementwiseOperation{}); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float ave_time = 0; + + if(reduce.IsSupportedArgument(argument_ptr.get())) + { + ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config); + } + else + { + throw ::std::runtime_error( + "The runtime parameters are not supported by the device instance."); + } + + return ave_time; + } + + float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{}) + { + auto arg = *dynamic_cast(&arg_); + + // workspace required when doing two-kernel reduce or Ds present + const bool need_workspace = !(!(arg.IsReduceAdd() || NumDTensor > 0) && + is_same::value); + if(need_workspace) + { + if(arg.p_workspace_ == nullptr) + { + throw ::std::runtime_error("using reduce, but empty workspace!"); + } + arg.p_e_grid = reinterpret_cast(arg.p_workspace_); + } + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw ::std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + ::std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + const auto kernel = + ::ck::kernel_gemm_wmma_cshuffle_v3; + ave_time = launch_and_time_kernel( + stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = + ::ck::kernel_gemm_wmma_cshuffle_v3; + ave_time = launch_and_time_kernel( + stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg); + } + + if(need_workspace) + { + ave_time += RunReduce(arg_, stream_config); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_wmma_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity( + *dynamic_cast(&arg)); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return GridwiseGemm::CalculateGridSize(M, N, KBatch); + } + + static constexpr index_t GetBlockSize() { return BlockSize; } + + static size_t GetSharedMemoryNumberOfByte() + { + return GridwiseGemm::GetSharedMemoryNumberOfByte(); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const ::std::array p_ds, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const ::std::array stride_ds, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_c, + M, + N, + K, + StrideA, + StrideB, + stride_ds, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + ::std::unique_ptr MakeInvokerPointer() override + { + return ::std::make_unique(Invoker{}); + } + + // Polymorphic interfaces + ::std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + ::std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + ::std::array DsStrides, + index_t StrideC, + index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return ::std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + DsStrides, + StrideC, + KSplit, + a_element_op, + b_element_op, + c_element_op); + } + + ::std::string GetTypeString() const override + { + auto str = ::std::stringstream(); + + auto BlkGemmPipelineSchedulerToString = [](BlockGemmPipelineScheduler s) { + switch(s) + { + case BlockGemmPipelineScheduler::Intrawave: return ::std::string("Intrawave"); + case BlockGemmPipelineScheduler::Interwave: return ::std::string("Interwave"); + } + return ::std::string("?"); + }; + + auto BlkGemmPipelineVersionToString = [](BlockGemmPipelineVersion v) { + switch(v) + { + case BlockGemmPipelineVersion::v1: return ::std::string("v1"); + case BlockGemmPipelineVersion::v2: return ::std::string("v2"); + case BlockGemmPipelineVersion::v3: return ::std::string("v3"); + case BlockGemmPipelineVersion::v4: return ::std::string("v4"); + case BlockGemmPipelineVersion::v5: return ::std::string("v5"); + } + return ::std::string("v?"); + }; + + // clang-format off + str << "DeviceGemmWmmaUniversalReduce" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << ::std::string(ALayout::name)[0] + << ::std::string(BLayout::name)[0] + << ::std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_arg); + + // Need workspace if using split-K or have D tensors + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && is_same::value)) + { + return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType); + } + + return 0; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index f779909e87..b226730a09 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -3,6 +3,11 @@ #pragma once +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#include +#endif + #include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" @@ -1049,6 +1054,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base { if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop + << ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages + << ") for pipeline version != v1." << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp index 7a3190e6f4..86bae7655b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp @@ -70,9 +70,9 @@ struct WmmaTraitsBase static constexpr index_t kRepeat = 1; static constexpr index_t kAMLane = 16; static constexpr index_t kBNLane = 16; - static constexpr index_t kABK0PerLane = 2; + static constexpr index_t kABK0PerLane = 1; static constexpr index_t kABKLane = 2; - static constexpr index_t kABK1PerLane = 4; + static constexpr index_t kABK1PerLane = 8; static constexpr index_t kCMLane = 2; static constexpr index_t kCNLane = 16; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp index 7727489e51..430a4e52f4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,6 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -20,6 +21,7 @@ namespace instance { using DsLayout = ck::Tuple<>; using DsDataType = ck::Tuple<>; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( std::vector>>& instances); +#endif +#endif +#ifdef CK_USE_WMMA +#if defined(CK_ENABLE_FP16) +void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); +#endif + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8)) +void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); +#endif + +#if defined(CK_ENABLE_BF16) +void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); +#endif #endif template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( op_ptrs); add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( @@ -395,6 +445,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( op_ptrs); +#endif + +#ifdef CK_USE_WMMA + add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( + op_ptrs); +#endif } } #endif @@ -406,6 +462,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( op_ptrs); add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( @@ -420,6 +477,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( op_ptrs); +#endif + +#ifdef CK_USE_WMMA + add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); +#endif } } #endif @@ -430,6 +493,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( op_ptrs); add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( @@ -444,6 +508,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( op_ptrs); +#endif + +#ifdef CK_USE_WMMA + add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); +#endif } } #endif diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt index 07263528b9..142ace2e42 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt @@ -1,6 +1,7 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_REDUCE_INSTANCES) +# XDL instances list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -30,4 +31,11 @@ list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp ) +# WMMA instances +list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + ) + add_instance_library(device_gemm_universal_reduce_instance ${GEMM_UNIVERSAL_REDUCE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..ee94046b8d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +using DsLayout = ck::Tuple<>; +using DsDataType = ck::Tuple<>; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce| + //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | | + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..20d88e4740 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using Row = tensor_layout::gemm::RowMajor; +using PassThrough = element_wise::PassThrough; + +void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + if(ck::is_gfx12_supported()) + { + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..3ddeec3c02 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +using DsLayout = ck::Tuple<>; +using DsDataType = ck::Tuple<>; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce| + //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | | + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 4, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..52589a258f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = bhalf_t; +using Row = tensor_layout::gemm::RowMajor; +using PassThrough = element_wise::PassThrough; + +void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + if(ck::is_gfx12_supported()) + { + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..564b81496d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +using DsLayout = ck::Tuple<>; +using DsDataType = ck::Tuple<>; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce| + //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | | + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3663ee6529 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using Row = tensor_layout::gemm::RowMajor; +using PassThrough = element_wise::PassThrough; +using Add = element_wise::Add; + +using DsLayout_F16 = ck::Tuple<>; +using DsDataType_F16 = ck::Tuple<>; + +void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + if(ck::is_gfx12_supported()) + { + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp index a0ee6a6674..32d2b38def 100644 --- a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp @@ -10,6 +10,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp" @@ -86,10 +87,21 @@ bool profile_gemm_universal_reduce_impl(int do_verification, switch(init_method) { - case 0: break; + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index ce8e652339..5538307232 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -68,7 +68,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp) - list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp) list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu.cpp) list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp) @@ -90,6 +89,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) + list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) @@ -185,7 +185,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance) - list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_instance) @@ -221,6 +220,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) + list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 947d5136be..c292400878 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -248,6 +248,7 @@ add_subdirectory(gemm_universal) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) +add_subdirectory(gemm_universal_reduce) add_subdirectory(batched_gemm) add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_gemm) diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt index 37774f7643..64672e200b 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt +++ b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt @@ -18,7 +18,7 @@ function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX) set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) endfunction() -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") create_tile_add_rmsnorm2d_rdquant_fwd("fp16") create_tile_add_rmsnorm2d_rdquant_fwd("bf16") else() diff --git a/test/ck_tile/batched_gemm/CMakeLists.txt b/test/ck_tile/batched_gemm/CMakeLists.txt index 532ead1124..9bcbc7352e 100644 --- a/test/ck_tile/batched_gemm/CMakeLists.txt +++ b/test/ck_tile/batched_gemm/CMakeLists.txt @@ -1,4 +1,3 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp) endif() diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index f634e508e3..1e2ea45b9e 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -27,21 +27,41 @@ class TestCkTileBatchedGemm : public ::testing::Test using DsLayout = ck_tile::tuple<>; using DsDataType = ck_tile::tuple<>; - template + struct GemmWarpConfig_Mfma + { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + }; + + struct GemmWarpConfig_Wmma + { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + }; + + template void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; + constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; constexpr bool DoubleSmemBuffer = false; @@ -255,9 +275,13 @@ class TestCkTileBatchedGemm : public ::testing::Test BatchStrideB, BatchStrideC, BatchCount}; - - invoke_batched_gemm(args, - ck_tile::stream_config{nullptr, false}); +#if CK_TILE_USE_WMMA + invoke_batched_gemm( + args, ck_tile::stream_config{nullptr, false}); +#else + invoke_batched_gemm( + args, ck_tile::stream_config{nullptr, false}); +#endif std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideC =" << StrideC diff --git a/test/ck_tile/batched_transpose/CMakeLists.txt b/test/ck_tile/batched_transpose/CMakeLists.txt index 111b7c2bed..fb45caf044 100644 --- a/test/ck_tile/batched_transpose/CMakeLists.txt +++ b/test/ck_tile/batched_transpose/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx950") add_gtest_executable(test_ck_tile_batched_transpose test_batched_transpose.cpp) set_property(TARGET test_ck_tile_batched_transpose PROPERTY CXX_STANDARD 20) else() diff --git a/test/ck_tile/container/CMakeLists.txt b/test/ck_tile/container/CMakeLists.txt index 50670c83e4..f13f0dbedf 100644 --- a/test/ck_tile/container/CMakeLists.txt +++ b/test/ck_tile/container/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp) if(result EQUAL 0) target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility) diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt index 384fd3c1c4..a5713ac55c 100644 --- a/test/ck_tile/data_type/CMakeLists.txt +++ b/test/ck_tile/data_type/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp) endif() if(GPU_TARGETS MATCHES "gfx95") diff --git a/test/ck_tile/elementwise/CMakeLists.txt b/test/ck_tile/elementwise/CMakeLists.txt index d22a30ff56..5fca0eb801 100644 --- a/test/ck_tile/elementwise/CMakeLists.txt +++ b/test/ck_tile/elementwise/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp) if(result EQUAL 0) target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility) diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp index 3ce6e78d1d..2eb2b506e8 100644 --- a/test/ck_tile/elementwise/test_elementwise_1d.cpp +++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -106,7 +106,7 @@ class TestCkTileElementwise : public ::testing::Test ck_tile::index_t grid_size = (total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM; dim3 grid(grid_size, 1, 1); - dim3 block(TestElementWiseShape::kBlockSize, 1, 1); + dim3 block = dim3(ew_kernel.BlockSize()); constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index 9ff5b442b4..f02ef1e55e 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -401,7 +401,7 @@ TEST_P(PagedKV, Test) 0, // scale_s 0, // logits_soft_cap is_v_rowmajor, // is_v_rowmajor - def_lse, // lse + false, // lse page_block_size, // page_block_size false, // use_cache_batch_idx "n", // bias_str diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 5d34943e0d..44e2433060 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -12,16 +12,16 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -enable-noalias-to-md-conversion=0 ) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) - - target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) - +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") + add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile_gemm tests for current target") +endif() +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) @@ -30,37 +30,47 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - - add_test_executable(test_ck_tile_gemm_pipeline_universal_int8 test_gemm_pipeline_universal_int8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_int8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_test_executable(test_ck_tile_gemm_pipeline_universal_pk_int4 test_gemm_pipeline_universal_pk_int4.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_pk_int4 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - -elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") - # On Radeon devices, build the WMMA version instead - add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp) - add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a") - add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE --save-temps -Wno-gnu-line-marker) add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -elseif(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") - add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +else() + message(DEBUG "Skipping ck_tile_gemm tests for current target ") +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") + if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + endif() + + if(GPU_TARGETS MATCHES "gfx11|gfx12") + # On Radeon devices, build the WMMA version instead + add_gtest_executable(test_ck_tile_gemm_pipeline_mem_wmma test_gemm_pipeline_mem_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv3_wmma test_gemm_pipeline_compv3_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv4_wmma test_gemm_pipeline_compv4_wmma.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_persistent_wmma test_gemm_pipeline_persistent_wmma.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_persistent_wmma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + endif() +else() + message(DEBUG "Skipping ck_tile_gemm tests for current target test_ck_tile_gemm_pipeline") endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 1fdf26f01c..706035cabc 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -13,6 +13,28 @@ #include "test_gemm_pipeline_smoke_util.hpp" #include "test_gemm_pipeline_smoke_run_test.inc" +struct GemmConfig_Mfma : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + +struct GemmConfig_Wmma : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template , @@ -130,7 +152,10 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) } } -template +template bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, ck_tile::ArgParser& arg_parser) @@ -142,12 +167,12 @@ bool run_gemm_test_prec_type(std::string a_layout, { if(a_layout == "R" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Col{}, Col{}, Row{}); } else @@ -160,22 +185,22 @@ bool run_gemm_test_prec_type(std::string a_layout, { if(a_layout == "R" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "R" && b_layout == "R") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_test_with_layouts( + return run_gemm_test_with_layouts( arg_parser, Col{}, Col{}, Row{}); } else @@ -185,7 +210,7 @@ bool run_gemm_test_prec_type(std::string a_layout, } } -template +template bool run_gemm_test(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -195,7 +220,8 @@ bool run_gemm_test(int argc, char* argv[]) std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); - return run_gemm_test_prec_type(a_layout, b_layout, arg_parser); + return run_gemm_test_prec_type( + a_layout, b_layout, arg_parser); } template @@ -255,8 +281,15 @@ int run_gemm_combinations() // Call the function with the current configuration try { - is_success = run_gemm_test(ARG_COUNT, argv) && +#if CK_TILE_USE_WMMA + is_success = run_gemm_test( + ARG_COUNT, argv) && is_success; +#else + is_success = run_gemm_test( + ARG_COUNT, argv) && + is_success; +#endif } catch(const ArgumentsNotSupportedException& e) { diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index f64d3e092b..52f6ea7026 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -220,6 +220,27 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; +template +struct GemmConfigComputeV3_WMMA : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + template struct GemmTypeConfig; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index fd50596f2f..dfee45cdfd 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -325,6 +325,13 @@ int run_gemm_combinations() // Call the function with the current configuration try { +#if CK_TILE_USE_WMMA + is_success = run_gemm_test, + APrecType, + BPrecType, + CPrecType>(ARG_COUNT, argv) && + is_success; +#else is_success = run_gemm_test, APrecType, BPrecType, @@ -335,6 +342,7 @@ int run_gemm_combinations() BPrecType, CPrecType>(ARG_COUNT, argv) && is_success; +#endif } catch(const ArgumentsNotSupportedException& e) { diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt index c9d53e53e2..143fb9dc40 100644 --- a/test/ck_tile/gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -1,10 +1,9 @@ -# Currently ck_tile is only built on gfx9 set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_gtest_executable(test_gemm_multi_d_cshuffle test_gemm_multi_d_cshuffle.cpp) add_gtest_executable(test_gemm_multi_d_default2d test_gemm_multi_d_default2d.cpp) target_compile_definitions(test_gemm_multi_d_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 8399bc7ee3..f0050c15d5 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -86,7 +86,28 @@ class TestCkTileGemmMultiD : public ::testing::Test using DsLayout = ck_tile::tuple; using DsDataType = ck_tile::tuple; - template & args, const ck_tile::stream_config& s) { - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 64; + constexpr ck_tile::index_t M_Tile = GemmWarpConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmWarpConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmWarpConfig::K_Tile; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; constexpr bool DoubleSmemBuffer = false; @@ -359,8 +380,9 @@ class TestCkTileGemmMultiD : public ::testing::Test StrideB, stridesDs, StrideE}); - - invoke_gemm_multi_d(args, ck_tile::stream_config{nullptr, false}); +#else + invoke_gemm_multi_d(args, ck_tile::stream_config{nullptr, false}); +#endif std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << StrideA << " StrideB =" << StrideB << " StrideE =" << StrideE diff --git a/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt b/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt index 4b9e6049e3..90803bd9d5 100644 --- a/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt +++ b/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt @@ -12,7 +12,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -enable-noalias-to-md-conversion=0 ) -if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_wp test_gemm_pipeline_wp.cpp) target_compile_options(test_ck_tile_gemm_pipeline_wp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp index f66f3cb0aa..b1521fc35a 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp @@ -31,8 +31,10 @@ using F8Types = std::tuple, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle>, - F8Types + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle> +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + , F8Types +#endif >; // clang-format on diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 5d52f15696..42d0149498 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -63,6 +63,23 @@ struct config static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; }; + +template +struct config_wmma +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(Datatype); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; +}; + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -79,13 +96,12 @@ class TestCkTileGemmPipeline : public ::testing::Test using DsLayout = ck_tile::tuple<>; using DsDataType = ck_tile::tuple<>; - using GemmConfig = config; static constexpr bool Persistent = ck_tile::tuple_element_or_default_t::value; // TODO: expose tile size through test t-param ? - template + template void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests @@ -253,6 +269,48 @@ class TestCkTileGemmPipeline : public ::testing::Test k_batches_ = {1}; } + template + auto shuffle_b(const ck_tile::HostTensor& t) + { + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if(ck_tile::is_gfx12_supported()) + { + constexpr int divisor = 2; + constexpr int kABK1PerLane = 8; + constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + kABK0PerLane, + divisor, + kABK1PerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); + } + else + { + int divisor = 1; + if(ck_tile::is_gfx11_supported()) + { + divisor = 1; + } + else + { + assert(is_wave32() == false); + divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + } + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, + divisor, + GemmConfig::K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + } template void Run(const int M, const int N, @@ -263,11 +321,17 @@ class TestCkTileGemmPipeline : public ::testing::Test { for(auto kb : k_batches_) { - RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); +#if CK_TILE_USE_WMMA + RunSingle, PadM, PadN, PadK, Preshuffle>( + M, N, K, StrideA, StrideB, StrideC, kb); +#else + RunSingle, PadM, PadN, PadK, Preshuffle>( + M, N, K, StrideA, StrideB, StrideC, kb); +#endif } } - template + template void RunSingle(const int M, const int N, const int K, @@ -327,16 +391,7 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({N / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - K / GemmConfig::K_Warp_Tile, - divisor, - GemmConfig::K_Warp_Tile / divisor}); - - std::copy(b_k_n.begin(), b_k_n.end(), t_view.begin()); - ck_tile::HostTensor b_shuffle_host = - ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); @@ -354,7 +409,8 @@ class TestCkTileGemmPipeline : public ::testing::Test stride_B, stride_C}; - invoke_gemm(args, ck_tile::stream_config{nullptr, false}); + invoke_gemm( + args, ck_tile::stream_config{nullptr, false}); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/test/ck_tile/grouped_gemm/CMakeLists.txt b/test/ck_tile/grouped_gemm/CMakeLists.txt index f4845847f1..4fd5c82ae9 100644 --- a/test/ck_tile/grouped_gemm/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ # Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp) endif() diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index 5aca02a433..6893318ea2 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -31,7 +31,7 @@ class TestCkTileGroupedGemm : public ::testing::Test using PersistentType = std::tuple_element_t<7, Tuple>; static constexpr bool Persistent = PersistentType::value; - struct GroupedGemKernelParam + struct GroupedGemKernelParam_Mfma { static const bool kPadM = false; static const bool kPadN = false; @@ -51,13 +51,24 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 16; }; + struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma + { + static const ck_tile::index_t M_Tile = 128; + static const ck_tile::index_t N_Tile = 128; + static const ck_tile::index_t K_Tile = 64; + + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = 16; + }; + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } - template + template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* kargs_ptr) @@ -200,7 +211,7 @@ class TestCkTileGroupedGemm : public ::testing::Test BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } - template + template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, void* kargs_ptr, @@ -460,15 +471,27 @@ class TestCkTileGroupedGemm : public ::testing::Test kargs.size() * sizeof(ck_tile::GemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); - invoke_grouped_gemm_persistent( +#if CK_TILE_USE_WMMA + invoke_grouped_gemm_persistent( stream, group_count, kargs_ptr, splitk); +#else + invoke_grouped_gemm_persistent( + stream, group_count, kargs_ptr, splitk); +#endif } else { - invoke_grouped_gemm( +#if CK_TILE_USE_WMMA + invoke_grouped_gemm( gemm_descs, ck_tile::stream_config{nullptr, false, 1}, gemm_workspace.GetDeviceBuffer()); +#else + invoke_grouped_gemm( + gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); +#endif } // Copy results back to host for validation diff --git a/test/ck_tile/image_to_column/CMakeLists.txt b/test/ck_tile/image_to_column/CMakeLists.txt index 247358dd4d..8873a846fc 100644 --- a/test/ck_tile/image_to_column/CMakeLists.txt +++ b/test/ck_tile/image_to_column/CMakeLists.txt @@ -1,4 +1,3 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp) endif() diff --git a/test/ck_tile/layernorm2d/CMakeLists.txt b/test/ck_tile/layernorm2d/CMakeLists.txt index c909d6cf40..e924f39e7a 100644 --- a/test/ck_tile/layernorm2d/CMakeLists.txt +++ b/test/ck_tile/layernorm2d/CMakeLists.txt @@ -14,7 +14,7 @@ function(create_tile_layernorm2d_fwd SUFFIX) target_compile_options(${TEST_CK_TILE_LAYERNORM2D_FWD} PRIVATE ${TEST_CK_TILE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) endfunction() -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd") set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") diff --git a/test/ck_tile/moe_smoothquant/CMakeLists.txt b/test/ck_tile/moe_smoothquant/CMakeLists.txt index b6c8a395b6..019e87323f 100644 --- a/test/ck_tile/moe_smoothquant/CMakeLists.txt +++ b/test/ck_tile/moe_smoothquant/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function (add_moe_smoothquant_test TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") add_gtest_executable(${TARGET_NAME} ${MAIN_SRC}) diff --git a/test/ck_tile/moe_sorting/CMakeLists.txt b/test/ck_tile/moe_sorting/CMakeLists.txt index 5abc7df5a9..48d8e1392f 100644 --- a/test/ck_tile/moe_sorting/CMakeLists.txt +++ b/test/ck_tile/moe_sorting/CMakeLists.txt @@ -1,5 +1,5 @@ -# Currently ck_tile is only built on gfx90a, gfx942 and gfx950 -if(GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950" OR GPU_TARGETS MATCHES "gfx90a") +# Currently ck_tile is only built on gfx90a, gfx942, gfx950, gfx11 and gfx12 +if(GPU_TARGETS MATCHES "gfx942|gfx950|gfx90a|gfx11|gfx12") function(add_moe_sorting_test EXECUTABLE USE_2D_BUF) add_gtest_executable(${EXECUTABLE} test_moe_sorting.cpp moe_sorting_api.cpp) diff --git a/test/ck_tile/permute/CMakeLists.txt b/test/ck_tile/permute/CMakeLists.txt index 4256ad8de1..8574813be3 100644 --- a/test/ck_tile/permute/CMakeLists.txt +++ b/test/ck_tile/permute/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function(add_permute_test TARGET_NAME MAIN_SRC) add_gtest_executable(${TARGET_NAME} ${MAIN_SRC}) diff --git a/test/ck_tile/permute/test_permute_util.hpp b/test/ck_tile/permute/test_permute_util.hpp index 5494749541..2028f56bb8 100644 --- a/test/ck_tile/permute/test_permute_util.hpp +++ b/test/ck_tile/permute/test_permute_util.hpp @@ -17,9 +17,11 @@ #include #include +#if !CK_TILE_USE_WMMA #ifdef PERMUTE_USE_ALTERNATIVE_IMPL #include "alternative_impl/matrix_core_swizzle.hpp" #endif +#endif namespace detail { template @@ -193,6 +195,7 @@ class TestCkTilePermute : public ::testing::Test return permute(a, stream_config); }; +#if !CK_TILE_USE_WMMA #ifdef PERMUTE_USE_ALTERNATIVE_IMPL // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 if((perm == std::string("0,1,4,2,5,3,6") || perm == std::string("0,1,2,4,5,3,6") || @@ -278,6 +281,7 @@ class TestCkTilePermute : public ::testing::Test } } else +#endif #endif { run_permute(); diff --git a/test/ck_tile/reduce/CMakeLists.txt b/test/ck_tile/reduce/CMakeLists.txt index 052669e20a..0ba5974f6c 100644 --- a/test/ck_tile/reduce/CMakeLists.txt +++ b/test/ck_tile/reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_reduce2d test_reduce2d.cpp) if(result EQUAL 0) target_link_libraries(test_ck_tile_reduce2d PRIVATE utility) diff --git a/test/ck_tile/reduce/test_reduce2d.cpp b/test/ck_tile/reduce/test_reduce2d.cpp index ff807e52c9..ded0406797 100644 --- a/test/ck_tile/reduce/test_reduce2d.cpp +++ b/test/ck_tile/reduce/test_reduce2d.cpp @@ -59,7 +59,7 @@ class TestCkTileReduce : public ::testing::Test using Kernel = ck_tile::Reduce; // Launch configuration - constexpr ck_tile::index_t kBlockSize = 256; + const ck_tile::index_t kBlockSize = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kGridSize = diff --git a/test/ck_tile/rmsnorm2d/CMakeLists.txt b/test/ck_tile/rmsnorm2d/CMakeLists.txt index 5a73b0914c..c60d73aafd 100644 --- a/test/ck_tile/rmsnorm2d/CMakeLists.txt +++ b/test/ck_tile/rmsnorm2d/CMakeLists.txt @@ -14,7 +14,7 @@ function(create_tile_rmsnorm2d_fwd SUFFIX) target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) endfunction() -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd") set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") diff --git a/test/ck_tile/smoothquant/CMakeLists.txt b/test/ck_tile/smoothquant/CMakeLists.txt index 548fc03a41..381923803f 100644 --- a/test/ck_tile/smoothquant/CMakeLists.txt +++ b/test/ck_tile/smoothquant/CMakeLists.txt @@ -1,5 +1,4 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function (add_smoothquant_test TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") diff --git a/test/ck_tile/topk_softmax/CMakeLists.txt b/test/ck_tile/topk_softmax/CMakeLists.txt index 046eaf6649..cd524eca01 100644 --- a/test/ck_tile/topk_softmax/CMakeLists.txt +++ b/test/ck_tile/topk_softmax/CMakeLists.txt @@ -10,8 +10,7 @@ function(add_tile_topk_softmax_test SUFFIX) target_compile_options(${TEST_NAME} PRIVATE ${TEST_TOPK_SOFTMAX_COMPILE_OPTIONS}) endfunction() -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_tile_topk_softmax_test(fp16) add_tile_topk_softmax_test(bf16) else() diff --git a/test/gemm_universal_reduce/CMakeLists.txt b/test/gemm_universal_reduce/CMakeLists.txt new file mode 100644 index 0000000000..dab9de44c0 --- /dev/null +++ b/test/gemm_universal_reduce/CMakeLists.txt @@ -0,0 +1,14 @@ +add_gtest_executable(test_gemm_universal_reduce_bf16_wmma test_gemm_universal_reduce_bf16_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_reduce_bf16_wmma PRIVATE utility device_gemm_universal_reduce_instance) +endif() + +add_gtest_executable(test_gemm_universal_reduce_fp16_wmma test_gemm_universal_reduce_fp16_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_reduce_fp16_wmma PRIVATE utility device_gemm_universal_reduce_instance) +endif() + +add_gtest_executable(test_gemm_universal_reduce_bf16A_i8_wmma test_gemm_universal_reduce_bf16A_i8_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_reduce_bf16A_i8_wmma PRIVATE utility device_gemm_universal_reduce_instance) +endif() diff --git a/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16A_i8_wmma.cpp b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16A_i8_wmma.cpp new file mode 100644 index 0000000000..ec4c0dc784 --- /dev/null +++ b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16A_i8_wmma.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "profiler/profile_gemm_universal_reduce_impl.hpp" + +TEST(GemmUniversalReduce, BF16A_I8) +{ + using Row = ck::tensor_layout::gemm::RowMajor; + + int M = 512; + int N = 256; + int K = 128; + int KBatch = 1; + + bool pass = true; + + pass = pass && ck::profiler::profile_gemm_universal_reduce_impl, + float, + ck::bhalf_t, + Row, + Row, + ck::Tuple<>, + Row>( + true, 3, false, true, M, N, K, K, N, N, KBatch, 1, 10); + EXPECT_TRUE(pass); +} diff --git a/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16_wmma.cpp b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16_wmma.cpp new file mode 100644 index 0000000000..cbc7860fd9 --- /dev/null +++ b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16_wmma.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "profiler/profile_gemm_universal_reduce_impl.hpp" + +TEST(GemmUniversalReduce, BF16) +{ + using Row = ck::tensor_layout::gemm::RowMajor; + + int M = 512; + int N = 256; + int K = 128; + int KBatch = 1; + + bool pass = true; + + pass = pass && ck::profiler::profile_gemm_universal_reduce_impl, + float, + ck::bhalf_t, + Row, + Row, + ck::Tuple<>, + Row>( + true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10); + EXPECT_TRUE(pass); +} diff --git a/test/gemm_universal_reduce/test_gemm_universal_reduce_fp16_wmma.cpp b/test/gemm_universal_reduce/test_gemm_universal_reduce_fp16_wmma.cpp new file mode 100644 index 0000000000..731bee89ed --- /dev/null +++ b/test/gemm_universal_reduce/test_gemm_universal_reduce_fp16_wmma.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "profiler/profile_gemm_universal_reduce_impl.hpp" + +TEST(GemmUniversalReduce, FP16) +{ + using Row = ck::tensor_layout::gemm::RowMajor; + + int M = 512; + int N = 256; + int K = 128; + int KBatch = 1; + + bool pass = true; + + pass = pass && ck::profiler::profile_gemm_universal_reduce_impl, + float, + ck::half_t, + Row, + Row, + ck::Tuple<>, + Row>( + true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10); + EXPECT_TRUE(pass); +}