diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 904006ba36..e0476bfaad 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -27,3 +27,16 @@ add_example_executable(example_gemm_xdl_splitk_reduce_multi_d_bf16 gemm_xdl_spli add_example_executable(example_gemm_xdl_splitk_reduce_bf16A_i8B gemm_xdl_splitk_reduce_bf16A_i8B.cpp) add_example_executable(example_gemm_xdl_splitk_reduce_bfp16 gemm_xdl_splitk_reduce_bf16.cpp) + +add_custom_target(example_splitK_gemm_wmma) +add_example_executable(example_gemm_wmma_splitk_reduce_bf16 gemm_wmma_splitk_reduce_bf16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16) + +add_example_executable(example_gemm_wmma_splitk_reduce_bf16A_i8B gemm_wmma_splitk_reduce_bf16A_i8B.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16A_i8B) + +add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_bf16 gemm_wmma_splitk_reduce_multi_d_bf16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_bf16) + +add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_fp16 gemm_wmma_splitk_reduce_multi_d_fp16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_fp16) diff --git a/example/35_splitK_gemm/common.hpp b/example/35_splitK_gemm/common.hpp index 64fadae9e5..325cc37731 100644 --- a/example/35_splitK_gemm/common.hpp +++ b/example/35_splitK_gemm/common.hpp @@ -99,3 +99,85 @@ bool parse_cmd_args(int argc, return true; } + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if 
constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp new file mode 100644 index 0000000000..b481483d42 --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = ck::bhalf_t; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceWmmaGemmInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_example.inc" + +int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp new file mode 100644 index 0000000000..dcf4a1652d --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = int8_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceWmmaGemmInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_example.inc" + +int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp new file mode 100644 index 0000000000..dab308d148 --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp new file mode 100644 index 0000000000..489816559d --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; +using ReduceDataType = float; +using D0DataType = ck::half_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 256, 64, + 8, 8, + 16, 16, + 4, 4, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); } diff --git a/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc index 9635993d63..0b060841bf 100644 --- a/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc +++ b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if 
constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc new file mode 100644 index 0000000000..25628ef770 --- /dev/null +++ b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +template +bool run_wmma_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + + Tensor 
c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << KBatch << std::endl; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device GEMM + auto device_op = DeviceWmmaGemmInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, // empty D tensors + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, // empty D strides + StrideC, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + // Allocate workspace for split-K reduction if needed + size_t workspace_size = device_op.GetWorkSpaceSize(argument.get()); + DeviceMem workspace_buf(workspace_size); + std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl; + if(workspace_size > 0) + { + argument->p_workspace_ = workspace_buf.GetDeviceBuffer(); + std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl; + } + + if(!device_op.IsSupportedArgument(argument.get())) + { + std::cout << "The runtime argument is not supported!" 
<< std::endl; + std::cout << "Debug info:" << std::endl; + std::cout << " M=" << M << ", N=" << N << ", K=" << K << ", KBatch=" << KBatch + << std::endl; + std::cout << " StrideA=" << StrideA << ", StrideB=" << StrideB << ", StrideC=" << StrideC + << std::endl; + return false; + } + + bool pass = true; + float ave_time = 0; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, false}); + + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel}); + // ave_time is reported in ms below, so TFlops = flop / 1e9 / ms and GB/s = bytes / 1e6 / ms. + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_wmma_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config); +} diff --git a/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc new file mode 100644 index 0000000000..59996655c6 --- /dev/null +++ b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc @@ -0,0 +1,214 
@@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_wmma_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto StrideD0 = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // a zero stride means "not specified": fall back to the packed stride for this layout + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + // From here on use the resolved strides, not problem_size.Stride*, so a zero stride is honored. + Tensor a_m_k( + f_host_tensor_descriptor(problem_size.M, problem_size.K, StrideA, ALayout{})); + Tensor b_k_n( + f_host_tensor_descriptor(problem_size.K, problem_size.N, StrideB, BLayout{})); + Tensor d0_m_n( + f_host_tensor_descriptor(problem_size.M, problem_size.N, StrideD0, D0Layout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(problem_size.M, problem_size.N, StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(problem_size.M, problem_size.N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << KBatch << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + constexpr auto kNum_DTensors = DsDataType::Size(); + const 
std::array p_ds = {d0_m_n_device_buf.GetDeviceBuffer()}; + const std::array d_strides = {StrideD0}; + // Pass the same resolved strides the host tensor descriptors were built with. + auto argument = + gemm.MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + p_ds, + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + problem_size.M, + problem_size.N, + problem_size.K, + StrideA, + StrideB, + d_strides, + StrideC, + problem_size.KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument.get())) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + return false; + } + + auto workspace_size = gemm.GetWorkSpaceSize(argument.get()); + DeviceMem workspace_device_buf(workspace_size); + + std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl; + std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl; + + if(workspace_size > 0) + { + argument->p_workspace_ = workspace_device_buf.GetDeviceBuffer(); + } + + if(config.do_verification) + { + using ReferenceGemmInstanceMultiD = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstanceMultiD{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + c_m_n_host_result.ForEach( + [&](auto& self, auto idx) { c_element_op(self(idx), self(idx), d0_m_n(idx)); }); + } + + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << problem_size.KBatch << std::endl; + + float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * problem_size.M * problem_size.N * problem_size.K; + std::size_t num_btype = sizeof(ADataType) * problem_size.M * problem_size.K + + sizeof(BDataType) * problem_size.K 
* problem_size.N + + sizeof(CDataType) * problem_size.M * problem_size.N + + sizeof(D0DataType) * problem_size.M * problem_size.N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + double rtol = get_rtol(); + double atol = get_atol(); + + return ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", rtol, atol); + } + + return true; +} + +int run_gemm_splitk_multi_d_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config); +} diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 2e949bb1df..6b04b21e4f 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -129,5 +129,10 @@ inline bool is_gfx103_supported() ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; } +inline bool is_wmma_supported() +{ + return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); +} + } // namespace ck #endif diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp new file mode 100644 index 0000000000..3a06ea8451 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1 +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + Tuple<>, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + ReduceDataType, + Tuple<>, + ReduceDataType, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 
BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + const ::std::array p_ds_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + const ::std::array stride_ds_, + index_t StrideC_, + index_t KBatch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + ::std::array{}, + reinterpret_cast(p_c_grid_), + M_, + N_, + K_, + StrideA_, + StrideB_, + std::array{}, + StrideC_, + KBatch_, + a_element_op_, + b_element_op_, + PassThrough{}, + true), + p_c_grid(p_c_grid_), + c_element_op(c_element_op_), + p_ds(p_ds_), + StrideDs(stride_ds_) + { + } + + CDataType* p_c_grid; + CElementwiseOperation c_element_op; + const ::std::array p_ds; + ::std::array StrideDs; + }; + + using ReduceAdd = ck::reduce::Add; + using OutElementwiseOperation = CElementwiseOperation; + + static constexpr auto DsVectorLengthSequence = generate_sequence_v2( + [](auto i) { + using DLayout = ::std::__remove_cvref_t>; + if constexpr(is_same::value) + return Number{}; + else + return Number<1>{}; + }, + Number{}); + + using DeviceReduceInstance = DeviceReduceThreadWiseMultiD< + ReduceDataType, // InDataType + DsDataType, // DsDatatype + GemmAccDataType, // AccDataType + CDataType, // OutDataType + 3, // Rank + 1, // NumReduceDim + ReduceAdd, + PassThrough, + OutElementwiseOperation, + 256, // BlockSize_ + CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_ + 1, // KThreadSliceSize_ + 0, // InSrcVectorDim_ + 
CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_ + CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_ + decltype(DsVectorLengthSequence)>; + + struct Invoker : public BaseInvoker + { + float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + static constexpr index_t NumInDim = 3; + static constexpr index_t NumOutDim = 2; + + ::std::array in_lengths = {arg.KBatch, arg.M, arg.N}; + ::std::array out_lengths = {arg.M, arg.N}; + + ::std::array in_strides; + ::std::array out_strides; + if constexpr(is_same::value) + { + in_strides = {arg.M * arg.N, arg.N, 1}; + out_strides = {arg.N, 1}; + } + else + { + in_strides = {arg.M * arg.N, 1, arg.M}; + out_strides = {1, arg.M}; + } + + ::std::array reduce_dims{0}; + + ::std::array<::std::array, NumDTensor> DsLengths; + ::std::array<::std::array, NumDTensor> DsStrides; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + DsLengths[i] = out_lengths; + + using DLayout = ::std::__remove_cvref_t>; + if constexpr(is_same::value) + { + DsStrides[i] = {arg.StrideDs[i], 1}; + } + else + { + DsStrides[i] = {1, arg.StrideDs[i]}; + } + }); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(in_lengths, + in_strides, + DsLengths, + DsStrides, + out_lengths, + out_strides, + reduce_dims, + arg.p_workspace_, + arg.p_ds, + arg.p_c_grid, + PassThrough{}, + OutElementwiseOperation{}); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float ave_time = 0; + + if(reduce.IsSupportedArgument(argument_ptr.get())) + { + ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config); + } + else + { + throw ::std::runtime_error( + "The runtime parameters are not supported by the device instance."); + } + + return ave_time; + } + + float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{}) + { + auto arg = *dynamic_cast(&arg_); + + // workspace required when doing two-kernel reduce or Ds present + const bool 
need_workspace = !(!(arg.IsReduceAdd() || NumDTensor > 0) && + is_same::value); + if(need_workspace) + { + if(arg.p_workspace_ == nullptr) + { + throw ::std::runtime_error("using reduce, but empty workspace!"); + } + arg.p_e_grid = reinterpret_cast(arg.p_workspace_); + } + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw ::std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + ::std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + const auto kernel = + ::ck::kernel_gemm_wmma_cshuffle_v3; + ave_time = launch_and_time_kernel( + stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = + ::ck::kernel_gemm_wmma_cshuffle_v3; + ave_time = launch_and_time_kernel( + stream_config, kernel, ::dim3(gdx, gdy, gdz), ::dim3(BlockSize), 0, arg); + } + + if(need_workspace) + { + ave_time += RunReduce(arg_, stream_config); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_wmma_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == 
GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity( + *dynamic_cast(&arg)); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return GridwiseGemm::CalculateGridSize(M, N, KBatch); + } + + static constexpr index_t GetBlockSize() { return BlockSize; } + + static size_t GetSharedMemoryNumberOfByte() + { + return GridwiseGemm::GetSharedMemoryNumberOfByte(); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const ::std::array p_ds, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const ::std::array stride_ds, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_c, + M, + N, + K, + StrideA, + StrideB, + stride_ds, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + ::std::unique_ptr MakeInvokerPointer() override + { + return ::std::make_unique(Invoker{}); + } + + // Polymorphic interfaces + ::std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + ::std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + ::std::array DsStrides, + index_t StrideC, + index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return ::std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + DsStrides, + StrideC, + KSplit, + a_element_op, + b_element_op, + c_element_op); + } + + ::std::string GetTypeString() const 
override + { + auto str = ::std::stringstream(); + + auto BlkGemmPipelineSchedulerToString = [](BlockGemmPipelineScheduler s) { + switch(s) + { + case BlockGemmPipelineScheduler::Intrawave: return ::std::string("Intrawave"); + case BlockGemmPipelineScheduler::Interwave: return ::std::string("Interwave"); + } + return ::std::string("?"); + }; + + auto BlkGemmPipelineVersionToString = [](BlockGemmPipelineVersion v) { + switch(v) + { + case BlockGemmPipelineVersion::v1: return ::std::string("v1"); + case BlockGemmPipelineVersion::v2: return ::std::string("v2"); + case BlockGemmPipelineVersion::v3: return ::std::string("v3"); + case BlockGemmPipelineVersion::v4: return ::std::string("v4"); + case BlockGemmPipelineVersion::v5: return ::std::string("v5"); + } + return ::std::string("v?"); + }; + + // clang-format off + str << "DeviceGemmWmmaUniversalReduce" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << ::std::string(ALayout::name)[0] + << ::std::string(BLayout::name)[0] + << ::std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_arg); + + // Need workspace if using split-K or have D tensors + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && is_same::value)) + { + return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType); + } + + return 0; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index f779909e87..b226730a09 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -3,6 +3,11 @@ #pragma once +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#include +#endif + #include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include 
"ck/tensor_description/multi_index_transform_helper.hpp" @@ -1049,6 +1054,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base { if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop + << ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages + << ") for pipeline version != v1." << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp index 7727489e51..430a4e52f4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -8,6 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -20,6 +21,7 @@ namespace instance { using DsLayout = ck::Tuple<>; using DsDataType = ck::Tuple<>; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( std::vector>>& instances); +#endif +#endif +#ifdef CK_USE_WMMA +#if defined(CK_ENABLE_FP16) +void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); +#endif + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8)) +void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); +#endif + +#if defined(CK_ENABLE_BF16) +void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances); +#endif #endif template && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( op_ptrs); add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( @@ -395,6 +445,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( op_ptrs); +#endif + +#ifdef CK_USE_WMMA + add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( + op_ptrs); +#endif } } #endif @@ -406,6 +462,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( op_ptrs); 
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances( @@ -420,6 +477,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( op_ptrs); +#endif + +#ifdef CK_USE_WMMA + add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); +#endif } } #endif @@ -430,6 +493,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( op_ptrs); add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( @@ -444,6 +508,12 @@ struct DeviceOperationInstanceFactory< op_ptrs); add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( op_ptrs); +#endif + +#ifdef CK_USE_WMMA + add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); +#endif } } #endif diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt index 07263528b9..142ace2e42 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt @@ -1,6 +1,7 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_REDUCE_INSTANCES) +# XDL instances list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_bf16_i8_bf16/device_gemm_xdl_universal_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp @@ -30,4 +31,11 @@ list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp ) +# WMMA instances 
+list(APPEND GEMM_UNIVERSAL_REDUCE_INSTANCES + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + ) + add_instance_library(device_gemm_universal_reduce_instance ${GEMM_UNIVERSAL_REDUCE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..ee94046b8d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +using DsLayout = ck::Tuple<>; +using DsDataType = ck::Tuple<>; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce| + //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | | + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, 
F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, 
PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 
32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..20d88e4740 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using Row = tensor_layout::gemm::RowMajor; +using PassThrough = element_wise::PassThrough; + +void add_device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + if(ck::is_gfx12_supported()) + { + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_bf16_bf16_mk_kn_mn_instances{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..3ddeec3c02 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +using DsLayout = ck::Tuple<>; +using DsDataType = ck::Tuple<>; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances = + std::tuple< + // clang-format off + 
//#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce| + //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | | + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 
16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 
8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 4, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 
1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 1, 1, S<1, 32, 1, 2>, 8, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, BF16, I8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 4, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..52589a258f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_bf16_i8_bf16/device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_wmma_universal_bf16_i8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = bhalf_t; +using Row = tensor_layout::gemm::RowMajor; +using PassThrough = element_wise::PassThrough; + +void add_device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + if(ck::is_gfx12_supported()) + { + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_bf16_i8_bf16_mk_kn_mn_instances{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..564b81496d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +using DsLayout = ck::Tuple<>; +using DsDataType = ck::Tuple<>; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template , + typename DsDataType = ck::Tuple<>> +using device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| DsLayout| CLayout| AData| BData| DsData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPerWmma|NPerWmma|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Reduce| + //#########################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | | | | | ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MRepeatPer|NRepeatPer| _MBlock_MRepeatPerShuffle_MWaveM| ScalarPerVector| Pipeline| Pipeline| DataType| + //#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Shuffle | Shuffle | PerShuffle_NBlock_NRepeatPerShuffle| _NPerBlock | Scheduler| Version| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NWaveNPerRepeat | | | | | + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, 
PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 
128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 1, 1, S<1, 32, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float>, + DeviceGemm_Wmma_CShuffleV3R1< Row, Row, DsLayout, Row, F16, F16, DsDataType, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, float> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3663ee6529 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using Row = tensor_layout::gemm::RowMajor; +using PassThrough = element_wise::PassThrough; +using Add = element_wise::Add; + +using DsLayout_F16 = ck::Tuple<>; +using DsDataType_F16 = ck::Tuple<>; + +void add_device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& instances) +{ + if(ck::is_gfx12_supported()) + { + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_gemm_wmma_universal_reduce_f16_f16_f16_mk_kn_mn_instances{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp 
b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp index a0ee6a6674..32d2b38def 100644 --- a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp @@ -10,6 +10,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp" @@ -86,10 +87,21 @@ bool profile_gemm_universal_reduce_impl(int do_verification, switch(init_method) { - case 0: break; + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index ce8e652339..5538307232 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -68,7 +68,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp) - list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_streamk.cpp) list(APPEND 
PROFILER_OPS profile_conv_fwd_bias_relu.cpp) list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp) @@ -90,6 +89,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) + list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) @@ -185,7 +185,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance) - list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_streamk_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_instance) @@ -221,6 +220,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) + list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 947d5136be..c292400878 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -248,6 +248,7 @@ add_subdirectory(gemm_universal) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) 
+add_subdirectory(gemm_universal_reduce) add_subdirectory(batched_gemm) add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_gemm) diff --git a/test/gemm_universal_reduce/CMakeLists.txt b/test/gemm_universal_reduce/CMakeLists.txt new file mode 100644 index 0000000000..dab9de44c0 --- /dev/null +++ b/test/gemm_universal_reduce/CMakeLists.txt @@ -0,0 +1,14 @@ +add_gtest_executable(test_gemm_universal_reduce_bf16_wmma test_gemm_universal_reduce_bf16_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_reduce_bf16_wmma PRIVATE utility device_gemm_universal_reduce_instance) +endif() + +add_gtest_executable(test_gemm_universal_reduce_fp16_wmma test_gemm_universal_reduce_fp16_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_reduce_fp16_wmma PRIVATE utility device_gemm_universal_reduce_instance) +endif() + +add_gtest_executable(test_gemm_universal_reduce_bf16A_i8_wmma test_gemm_universal_reduce_bf16A_i8_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_reduce_bf16A_i8_wmma PRIVATE utility device_gemm_universal_reduce_instance) +endif() diff --git a/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16A_i8_wmma.cpp b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16A_i8_wmma.cpp new file mode 100644 index 0000000000..ec4c0dc784 --- /dev/null +++ b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16A_i8_wmma.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" +#include "profiler/profile_gemm_universal_reduce_impl.hpp" + +TEST(GemmUniversalReduce, BF16A_I8) +{ + using Row = ck::tensor_layout::gemm::RowMajor; + + int M = 512; + int N = 256; + int K = 128; + int KBatch = 1; + + bool pass = true; + + pass = pass && ck::profiler::profile_gemm_universal_reduce_impl, + float, + ck::bhalf_t, + Row, + Row, + ck::Tuple<>, + Row>( + true, 3, false, true, M, N, K, K, N, N, KBatch, 1, 10); + EXPECT_TRUE(pass); +} diff --git a/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16_wmma.cpp b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16_wmma.cpp new file mode 100644 index 0000000000..cbc7860fd9 --- /dev/null +++ b/test/gemm_universal_reduce/test_gemm_universal_reduce_bf16_wmma.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "profiler/profile_gemm_universal_reduce_impl.hpp" + +TEST(GemmUniversalReduce, BF16) +{ + using Row = ck::tensor_layout::gemm::RowMajor; + + int M = 512; + int N = 256; + int K = 128; + int KBatch = 1; + + bool pass = true; + + pass = pass && ck::profiler::profile_gemm_universal_reduce_impl, + float, + ck::bhalf_t, + Row, + Row, + ck::Tuple<>, + Row>( + true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10); + EXPECT_TRUE(pass); +} diff --git a/test/gemm_universal_reduce/test_gemm_universal_reduce_fp16_wmma.cpp b/test/gemm_universal_reduce/test_gemm_universal_reduce_fp16_wmma.cpp new file mode 100644 index 0000000000..731bee89ed --- /dev/null +++ b/test/gemm_universal_reduce/test_gemm_universal_reduce_fp16_wmma.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" +#include "profiler/profile_gemm_universal_reduce_impl.hpp" + +TEST(GemmUniversalReduce, FP16) +{ + using Row = ck::tensor_layout::gemm::RowMajor; + + int M = 512; + int N = 256; + int K = 128; + int KBatch = 1; + + bool pass = true; + + pass = pass && ck::profiler::profile_gemm_universal_reduce_impl, + float, + ck::half_t, + Row, + Row, + ck::Tuple<>, + Row>( + true, 1, false, true, M, N, K, K, N, N, KBatch, 1, 10); + EXPECT_TRUE(pass); +}