diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt
index a9e0d3f9ad..ffc6cec61d 100644
--- a/example/60_gemm_multi_ABD/CMakeLists.txt
+++ b/example/60_gemm_multi_ABD/CMakeLists.txt
@@ -1,3 +1,7 @@
+add_example_executable(example_gemm_multi_ABD_wmma_fp16 gemm_multi_ABD_wmma_fp16.cpp)
+add_example_executable(example_gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp)
+add_example_executable(example_gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp)
+add_example_executable(example_gemm_multi_ABD_wmma_fastgelu_bf16_i8 gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp)
 add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
 add_example_executable(example_gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp)
 add_example_executable(example_gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp)
diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp
new file mode 100644
index 0000000000..a30314f58c
--- /dev/null
+++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp
@@ -0,0 +1,307 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <initializer_list>
+#include <iostream>
+#include <numeric>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+#include "ck/utility/blkgemmpipe_scheduler.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = BF16;
+using D0DataType       = BF16;
+using DsDataType       = ck::Tuple<D0DataType>;
+using EDataType        = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<D0Layout>;
+using ELayout  = Row;
+
+using Multiply    = ck::tensor_operation::element_wise::Multiply;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Multiply;
+using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
+
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
+    AsLayout,
+    BsLayout,
+    DsLayout,
+    ELayout,
+    AsDataType,
+    BsDataType,
+    AccDataType,
+    CShuffleDataType,
+    DsDataType,
+    EDataType,
+    AElementOp,
+    BElementOp,
+    CDEElementOp,
+    GemmSpec,
+    256,
+    128,
+    128,
+    64,
+    8,
+    8,
+    16,
+    16,
+    4,
+    2,
+    S<8, 32, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    0,
+    S<8, 32, 1>,
+    S<0, 2, 1>,
+    S<0, 2, 1>,
+    1,
+    1,
+    8,
+    0,
+    1,
+    1,
+    S<1, 32, 1, 8>,
+    S<8, 8, 8>,
+    ck::BlockGemmPipelineScheduler::Intrawave,
+    ck::BlockGemmPipelineVersion::v3>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 2;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 4096;
+    ck::index_t N = 768;
+    ck::index_t K = 6144;
+
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = N;
+    ck::index_t StrideD = N;
+    ck::index_t StrideE = N;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 11)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideD = std::stoi(argv[9]);
+        StrideE = std::stoi(argv[10]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            using namespace ck::literals;
+
+            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
+    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+    Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
+    Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+    std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
+    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
+    std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
+    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
+        break;
+    default:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
+    }
+
+    DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+
+    a0_device_buf.ToDevice(a0_m_k.mData.data());
+    b0_device_buf.ToDevice(b0_k_n.mData.data());
+    b1_device_buf.ToDevice(b1_k_n.mData.data());
+    d_device_buf.ToDevice(d_m_n.mData.data());
+    e_device_buf.ToDevice(e_m_n_device_result.mData.data());
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 1;
+
+    // do GEMM
+    auto device_op = DeviceOpInstance{};
+    auto invoker   = device_op.MakeInvoker();
+    auto argument  = device_op.MakeArgument(
+        std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+        std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                            b1_device_buf.GetDeviceBuffer()},
+        std::array<const void*, NumDTensor>{d_device_buf.GetDeviceBuffer()},
+        e_device_buf.GetDeviceBuffer(),
+        M,
+        N,
+        K,
+        std::array<ck::index_t, NumATensor>{StrideA},
+        std::array<ck::index_t, NumBTensor>{StrideB, StrideB},
+        std::array<ck::index_t, NumDTensor>{StrideD},
+        StrideE,
+        1,
+        a_element_op,
+        b_element_op,
+        cde_element_op);
+
+    if(!device_op.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+    if(do_verification)
+    {
+        Tensor<CShuffleDataType> c_m_n({M, N});
+
+        Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+
+        for(int n = 0; n < N; ++n)
+        {
+            for(int k = 0; k < K; ++k)
+            {
+                b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
+            }
+        }
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
+                                                                                B1DataType,
+                                                                                CShuffleDataType,
+                                                                                AccDataType,
+                                                                                PassThrough,
+                                                                                PassThrough,
+                                                                                PassThrough>;
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument = ref_gemm.MakeArgument(
+            a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
+            }
+        }
+
+        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
+    }
+
+    return 0;
+}
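A quick host-side model of the fused epilogue used by this example. The function name is ours, float stands in for bf16, and the formula assumes CK's tanh-style FastGelu approximation applied to c + d (accumulator plus bias D):

    #include <cmath>

    // add_fast_gelu_ref: hypothetical reference for AddFastGelu, assuming
    // e = FastGelu(c + d) with FastGelu(x) ~ 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
    float add_fast_gelu_ref(float c, float d)
    {
        const float x = c + d;                                       // GEMM result plus bias
        const float u = 0.7978845608f * (x + 0.044715f * x * x * x); // sqrt(2/pi) inner term
        return 0.5f * x * (1.0f + std::tanh(u));
    }

This is exactly the per-element check the verification loop above performs through cde_element_op.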
diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp
new file mode 100644
index 0000000000..086a0f4834
--- /dev/null
+++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp
@@ -0,0 +1,299 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <initializer_list>
+#include <iostream>
+#include <numeric>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+#include "ck/utility/blkgemmpipe_scheduler.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = F32;
+using DsDataType       = ck::Tuple<>;
+using EDataType        = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using DsLayout = ck::Tuple<>;
+using ELayout  = Row;
+
+using Multiply    = ck::tensor_operation::element_wise::Multiply;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using FastGelu    = ck::tensor_operation::element_wise::FastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Multiply;
+using CDEElementOp = FastGelu;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
+
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
+    AsLayout,
+    BsLayout,
+    DsLayout,
+    ELayout,
+    AsDataType,
+    BsDataType,
+    AccDataType,
+    CShuffleDataType,
+    DsDataType,
+    EDataType,
+    AElementOp,
+    BElementOp,
+    CDEElementOp,
+    GemmSpec,
+    256,
+    128,
+    128,
+    64,
+    8,
+    8,
+    16,
+    16,
+    4,
+    2,
+    S<8, 32, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    0,
+    S<8, 32, 1>,
+    S<0, 2, 1>,
+    S<0, 2, 1>,
+    1,
+    1,
+    8,
+    0,
+    1,
+    1,
+    S<1, 32, 1, 8>,
+    S<8, 8, 8>,
+    ck::BlockGemmPipelineScheduler::Intrawave,
+    ck::BlockGemmPipelineVersion::v3>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 2;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 4096;
+    ck::index_t N = 768;
+    ck::index_t K = 6144;
+
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = N;
+    ck::index_t StrideE = N;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 10)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideE = std::stoi(argv[9]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            using namespace ck::literals;
+
+            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
+    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+    Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+    std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
+    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
+    std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
+        break;
+    default:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
+    }
+
+    DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+
+    a0_device_buf.ToDevice(a0_m_k.mData.data());
+    b0_device_buf.ToDevice(b0_k_n.mData.data());
+    b1_device_buf.ToDevice(b1_k_n.mData.data());
+    e_device_buf.ToDevice(e_m_n_device_result.mData.data());
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 0;
+
+    // do GEMM
+    auto device_op = DeviceOpInstance{};
+    auto invoker   = device_op.MakeInvoker();
+    auto argument  = device_op.MakeArgument(
+        std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+        std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                            b1_device_buf.GetDeviceBuffer()},
+        std::array<const void*, NumDTensor>{},
+        e_device_buf.GetDeviceBuffer(),
+        M,
+        N,
+        K,
+        std::array<ck::index_t, NumATensor>{StrideA},
+        std::array<ck::index_t, NumBTensor>{StrideB, StrideB},
+        std::array<ck::index_t, NumDTensor>{},
+        StrideE,
+        1,
+        a_element_op,
+        b_element_op,
+        cde_element_op);
+
+    if(!device_op.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+    if(do_verification)
+    {
+        Tensor<CShuffleDataType> c_m_n({M, N});
+
+        Tensor<A0DataType> a_m_k({M, K});
+
+        Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+
+        for(int n = 0; n < N; ++n)
+        {
+            for(int k = 0; k < K; ++k)
+            {
+                b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
+            }
+        }
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
                                                                                B1DataType,
                                                                                CShuffleDataType,
                                                                                AccDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument = ref_gemm.MakeArgument(
+            a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
+            }
+        }
+
+        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
+    }
+
+    return 0;
+}
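In both bf16×i8 examples above, the second B tensor carries a per-element bf16 scale and BElementOp = Multiply fuses the int8-weight dequantization into the operand load. A minimal host model of that op (name is ours; float stands in for ck::bhalf_t):

    #include <cstdint>

    // multiply_ref: hypothetical host model of the B-side Multiply op,
    // b(k, n) = float(b0_i8(k, n)) * b1_scale(k, n)
    float multiply_ref(std::int8_t b0, float b1) { return static_cast<float>(b0) * b1; }

The verification loops above apply the same op on the host to materialize a dequantized b_k_n before running the reference GEMM.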
diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp
new file mode 100644
index 0000000000..32345d1263
--- /dev/null
+++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp
@@ -0,0 +1,362 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <initializer_list>
+#include <iostream>
+#include <numeric>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType        = F16;
+using BDataType        = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F32;
+using DDataType        = F16;
+using EDataType        = F16;
+
+using ALayout = Row;
+using BLayout = Row;
+using DLayout = Row;
+using ELayout = Row;
+
+struct AddScale
+{
+    static constexpr auto I0 = ck::Number<0>{};
+    static constexpr auto I1 = ck::Number<1>{};
+    static constexpr auto I2 = ck::Number<2>{};
+    static constexpr auto I3 = ck::Number<3>{};
+
+    __host__ __device__ constexpr void
+    operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const
+    {
+        const auto a0_v_t = ck::vector_type<ck::half_t, 4>{a0};
+        const auto a1_v_t = ck::vector_type<ck::half_t, 4>{a1};
+
+        auto r_v_t = ck::vector_type<ck::half_t, 4>{};
+
+        r_v_t.AsType<ck::half_t>()(I0) =
+            scale * (a0_v_t.AsType<ck::half_t>()[I0] + a1_v_t.AsType<ck::half_t>()[I0]);
+        r_v_t.AsType<ck::half_t>()(I1) =
+            scale * (a0_v_t.AsType<ck::half_t>()[I1] + a1_v_t.AsType<ck::half_t>()[I1]);
+        r_v_t.AsType<ck::half_t>()(I2) =
+            scale * (a0_v_t.AsType<ck::half_t>()[I2] + a1_v_t.AsType<ck::half_t>()[I2]);
+        r_v_t.AsType<ck::half_t>()(I3) =
+            scale * (a0_v_t.AsType<ck::half_t>()[I3] + a1_v_t.AsType<ck::half_t>()[I3]);
+
+        a = r_v_t.AsType<ck::half4_t>()[I0];
+    }
+
+    __host__ __device__ constexpr void
+    operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
+    {
+        a = scale * (a0 + a1);
+    }
+
+    // This attribute lets the copy function apply the elementwise op on
+    // packed 4-element data (ck::half4_t) instead of element by element.
+    constexpr const static bool is_pack4_invocable = true;
+
+    float scale = 1.0;
+};
+
+struct AlphaBetaAdd
+{
+    AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){};
+
+    template <typename E, typename C, typename D>
+    __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<ck::half_t, float, ck::half_t>(
+        ck::half_t& e, const float& c, const ck::half_t& d) const
+    {
+        e = ck::type_convert<ck::half_t>(alpha_ * c + beta_ * ck::type_convert<float>(d));
+    };
+
+    float alpha_;
+    float beta_;
+};
+
+using AElementOp   = AddScale;
+using BElementOp   = PassThrough;
+using CDEElementOp = AlphaBetaAdd;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
+    ck::Tuple<ALayout, ALayout>,
+    ck::Tuple<BLayout>,
+    ck::Tuple<DLayout>,
+    ELayout,
+    ck::Tuple<ADataType, ADataType>,
+    ck::Tuple<BDataType>,
+    AccDataType,
+    CShuffleDataType,
+    ck::Tuple<DDataType>,
+    EDataType,
+    AElementOp,
+    BElementOp,
+    CDEElementOp,
+    GemmSpec,
+    256,
+    256,
+    128,
+    32,
+    8,
+    8,
+    16,
+    16,
+    4,
+    4,
+    S<4, 64, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    0,
+    S<4, 64, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    1,
+    1,
+    8,
+    0,
+    1,
+    1,
+    S<1, 64, 1, 4>,
+    S<8, 8, 8>>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = N;
+    ck::index_t StrideD = N;
+    ck::index_t StrideE = N;
+
+    float alpha = 1.0f;
+    float beta  = 1.0f;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 6)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        alpha = std::stof(argv[4]);
+        beta  = std::stof(argv[5]);
+    }
+    else if(argc == 13)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideD = std::stoi(argv[9]);
+        StrideE = std::stoi(argv[10]);
+
+        alpha = std::stof(argv[11]);
+        beta  = std::stof(argv[12]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 12: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
+               "beta\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            using namespace ck::literals;
+
+            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+    std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
+    std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        a1_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
+        break;
+    default:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        a1_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
+    }
+
+    DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+
+    a0_device_buf.ToDevice(a0_m_k.mData.data());
+    a1_device_buf.ToDevice(a1_m_k.mData.data());
+    b_device_buf.ToDevice(b_k_n.mData.data());
+    d_device_buf.ToDevice(d_m_n.mData.data());
+    e_device_buf.ToDevice(e_m_n_device_result.mData.data());
+
+    auto a_element_op   = AElementOp{0.2};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{alpha, beta};
+
+    // do GEMM
+    auto device_op = DeviceOpInstance{};
+    auto invoker   = device_op.MakeInvoker();
+    auto argument  = device_op.MakeArgument(
+        std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
+                                   a1_device_buf.GetDeviceBuffer()},
+        std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
+        std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
+        e_device_buf.GetDeviceBuffer(),
+        M,
+        N,
+        K,
+        std::array<ck::index_t, 2>{StrideA, StrideA},
+        std::array<ck::index_t, 1>{StrideB},
+        std::array<ck::index_t, 1>{StrideD},
+        StrideE,
+        1,
+        a_element_op,
+        b_element_op,
+        cde_element_op);
+
+    if(!device_op.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+    if(do_verification)
+    {
+        Tensor<CShuffleDataType> c_m_n({M, N});
+
+        Tensor<ADataType> a_m_k({M, K});
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int k = 0; k < K; ++k)
+            {
+                a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k));
+            }
+        }
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                CShuffleDataType,
                                                                                AccDataType,
                                                                                PassThrough,
                                                                                BElementOp,
                                                                                PassThrough>;
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument =
+            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
+            }
+        }
+
+        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
+    }
+
+    return 0;
+}
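The is_pack4_invocable flag above tells the copy pipeline it may hand AddScale four half values at once (the half4_t overload) rather than invoking the scalar overload four times. Per element, both paths compute the same thing; a scalar model (name is ours, float stands in for ck::half_t):

    // add_scale_ref: per-element semantics of AddScale; the half4_t overload
    // applies this same formula to each of the four packed lanes.
    float add_scale_ref(float a0, float a1, float scale) { return scale * (a0 + a1); }

With AElementOp{0.2} as constructed in main, the kernel effectively computes C = (0.2 * (A0 + A1)) * B before the AlphaBetaAdd epilogue.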
diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp
new file mode 100644
index 0000000000..00e2d7e33c
--- /dev/null
+++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp
@@ -0,0 +1,296 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+#include <initializer_list>
+#include <iostream>
+#include <numeric>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+#include "ck/library/utility/check_err.hpp"
+
+#include "ck/utility/blkgemmpipe_scheduler.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using BsDataType       = ck::Tuple<B0DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = F32;
+using D0DataType       = BF16;
+using D1DataType       = BF16;
+using DsDataType       = ck::Tuple<D0DataType, D1DataType>;
+using EDataType        = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using BsLayout = ck::Tuple<B0Layout>;
+using D0Layout = Row;
+using D1Layout = D0Layout;
+using DsLayout = ck::Tuple<D0Layout, D1Layout>;
+using ELayout  = Row;
+
+using PassThrough         = ck::tensor_operation::element_wise::PassThrough;
+using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = PassThrough;
+using CDEElementOp = MultiplyAddFastGelu;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
+
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3<
+    AsLayout,
+    BsLayout,
+    DsLayout,
+    ELayout,
+    AsDataType,
+    BsDataType,
+    AccDataType,
+    CShuffleDataType,
+    DsDataType,
+    EDataType,
+    AElementOp,
+    BElementOp,
+    CDEElementOp,
+    GemmSpec,
+    256,
+    128,
+    128,
+    64,
+    8,
+    8,
+    16,
+    16,
+    4,
+    2,
+    S<8, 32, 1>,
+    S<1, 0, 2>,
+    S<1, 0, 2>,
+    2,
+    8,
+    8,
+    0,
+    S<8, 32, 1>,
+    S<0, 2, 1>,
+    S<0, 2, 1>,
+    1,
+    1,
+    8,
+    0,
+    1,
+    1,
+    S<1, 32, 1, 8>,
+    S<8, 8, 8>,
+    ck::BlockGemmPipelineScheduler::Intrawave,
+    ck::BlockGemmPipelineVersion::v3>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 4096;
+    ck::index_t N = 768;
+    ck::index_t K = 6144;
+
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = N;
+    ck::index_t StrideD = N;
+    ck::index_t StrideE = N;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else if(argc == 11)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+
+        M = std::stoi(argv[4]);
+        N = std::stoi(argv[5]);
+        K = std::stoi(argv[6]);
+
+        StrideA = std::stoi(argv[7]);
+        StrideB = std::stoi(argv[8]);
+        StrideD = std::stoi(argv[9]);
+        StrideE = std::stoi(argv[10]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
+        exit(0);
+    }
+
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            using namespace ck::literals;
+
+            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
+            {
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
+    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
+    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+    std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
+    std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
+    std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
+    std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
+        break;
+    default:
+        a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
+    }
+
+    DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+
+    a0_device_buf.ToDevice(a0_m_k.mData.data());
+    b0_device_buf.ToDevice(b0_k_n.mData.data());
+    d0_device_buf.ToDevice(d0_m_n.mData.data());
+    d1_device_buf.ToDevice(d1_m_n.mData.data());
+    e_device_buf.ToDevice(e_m_n_device_result.mData.data());
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 1;
+    constexpr ck::index_t NumDTensor = 2;
+
+    // do GEMM
+    auto device_op = DeviceOpInstance{};
+    auto invoker   = device_op.MakeInvoker();
+    auto argument  = device_op.MakeArgument(
+        std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+        std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
+        std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
+                                            d1_device_buf.GetDeviceBuffer()},
+        e_device_buf.GetDeviceBuffer(),
+        M,
+        N,
+        K,
+        std::array<ck::index_t, NumATensor>{StrideA},
+        std::array<ck::index_t, NumBTensor>{StrideB},
+        std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
+        StrideE,
+        1,
+        a_element_op,
+        b_element_op,
+        cde_element_op);
+
+    if(!device_op.IsSupportedArgument(argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_btype =
+        sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
+              << std::endl;
+
+    e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+    if(do_verification)
+    {
+        Tensor<CShuffleDataType> c_m_n({M, N});
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
                                                                                B0DataType,
                                                                                CShuffleDataType,
                                                                                AccDataType,
                                                                                PassThrough,
                                                                                PassThrough,
                                                                                PassThrough>;
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument = ref_gemm.MakeArgument(
+            a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
+            }
+        }
+
+        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+        return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
+    }
+
+    return 0;
+}
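In this variant the dequantization moves into the epilogue: D0 carries the per-element scale and D1 the bias. A host model assuming MultiplyAddFastGelu computes e = FastGelu(c * d0 + d1), using the same tanh-style approximation sketched earlier (function name is ours):

    #include <cmath>

    // multiply_add_fast_gelu_ref: hypothetical reference, e = FastGelu(c*d0 + d1)
    float multiply_add_fast_gelu_ref(float c, float d0, float d1)
    {
        const float x = c * d0 + d1; // scale the raw i8 accumulator, then add bias
        const float u = 0.7978845608f * (x + 0.044715f * x * x * x);
        return 0.5f * x * (1.0f + std::tanh(u));
    }
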
diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp
index 08b3aba2b3..5da447125e 100644
--- a/include/ck/host_utility/flush_cache.hpp
+++ b/include/ck/host_utility/flush_cache.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -15,6 +15,151 @@
 namespace ck {
 namespace utility {
 
+template <typename Argument, typename AsDataType, typename BsDataType, typename DsDataType>
+struct RotatingMemWrapperMultiABD
+{
+    static constexpr index_t NumAs = AsDataType::Size();
+    static constexpr index_t NumBs = BsDataType::Size();
+    static constexpr index_t NumDs = DsDataType::Size();
+
+    using AsGridPointer = decltype(Argument::p_as_grid);
+    using BsGridPointer = decltype(Argument::p_bs_grid);
+    using DsGridPointer = decltype(Argument::p_ds_grid);
+
+    RotatingMemWrapperMultiABD() = delete;
+    RotatingMemWrapperMultiABD(Argument& arg_,
+                               std::size_t rotating_count_,
+                               std::array<std::size_t, NumAs> size_as_,
+                               std::array<std::size_t, NumBs> size_bs_,
+                               std::array<std::size_t, NumDs> size_ds_)
+        : arg(arg_),
+          rotating_count(rotating_count_),
+          size_as(size_as_),
+          size_bs(size_bs_),
+          size_ds(size_ds_)
+    {
+        p_as_grids.push_back(arg.p_as_grid);
+        p_bs_grids.push_back(arg.p_bs_grid);
+        p_ds_grids.push_back(arg.p_ds_grid);
+        for(size_t i = 1; i < rotating_count; i++)
+        {
+            {
+                AsGridPointer as_buffer;
+                static_for<0, NumAs, 1>{}([&](auto j) {
+                    void* pADeviceBuf;
+                    hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
+                    hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
+                                              static_cast<const void*>(p_as_grids[0][j]),
+                                              size_as_[j],
+                                              hipMemcpyDeviceToDevice));
+                    using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
+
+                    as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
+                });
+                p_as_grids.push_back(as_buffer);
+            }
+
+            {
+                BsGridPointer bs_buffer;
+                static_for<0, NumBs, 1>{}([&](auto j) {
+                    void* pBDeviceBuf;
+                    hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
+                    hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
+                                              static_cast<const void*>(p_bs_grids[0][j]),
+                                              size_bs_[j],
+                                              hipMemcpyDeviceToDevice));
+                    using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
+
+                    bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
+                });
+                p_bs_grids.push_back(bs_buffer);
+            }
+
+            {
+                DsGridPointer ds_buffer;
+                static_for<0, NumDs, 1>{}([&](auto j) {
+                    void* pDDeviceBuf;
+                    hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
+                    hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
+                                              static_cast<const void*>(p_ds_grids[0][j]),
+                                              size_ds_[j],
+                                              hipMemcpyDeviceToDevice));
+
+                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
+
+                    ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
+                });
+
+                p_ds_grids.push_back(ds_buffer);
+            }
+        }
+    }
+
+    void Next()
+    {
+        if(rotating_count > 1)
+        {
+            std::size_t idx = iter++ % rotating_count;
+            arg.p_as_grid   = p_as_grids[idx];
+            arg.p_bs_grid   = p_bs_grids[idx];
+            arg.p_ds_grid   = p_ds_grids[idx];
+        }
+    }
+    void Print()
+    {
+        std::cout << "RotatingMemWrapperMultiABD: { size_a: {";
+        static_for<0, NumAs, 1>{}(
+            [&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
+        std::cout << "}, size_b: {";
+        static_for<0, NumBs, 1>{}(
+            [&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
+        std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
+    }
+    ~RotatingMemWrapperMultiABD()
+    {
+        if(rotating_count > 1)
+        {
+            // restore ptr
+            arg.p_as_grid = p_as_grids[0];
+            arg.p_bs_grid = p_bs_grids[0];
+            arg.p_ds_grid = p_ds_grids[0];
+
+            // free device mem
+            for(size_t i = 1; i < rotating_count; i++)
+            {
+                static_for<0, NumAs, 1>{}([&](auto j) {
+                    using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
+                    hip_check_error(
+                        hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
+                });
+
+                static_for<0, NumBs, 1>{}([&](auto j) {
+                    using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
+                    hip_check_error(
+                        hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
+                });
+
+                static_for<0, NumDs, 1>{}([&](auto j) {
+                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
+                    hip_check_error(
+                        hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
+                });
+            }
+        }
+    }
+
+    private:
+    Argument& arg;
+    std::size_t iter                        = 0;
+    std::size_t rotating_count              = 1;
+    std::array<std::size_t, NumAs> size_as  = {0};
+    std::array<std::size_t, NumBs> size_bs  = {0};
+    std::array<std::size_t, NumDs> size_ds  = {0};
+    std::vector<AsGridPointer> p_as_grids;
+    std::vector<BsGridPointer> p_bs_grids;
+    std::vector<DsGridPointer> p_ds_grids;
+};
+
 template <typename Argument, typename DsDataType>
 struct RotatingMemWrapperMultiD
 {
@@ -318,6 +463,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             //     total_time += cur_time;
             // #endif
 
+#if !defined(CK_USE_WMMA)
             if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
             {
                 // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
                 printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid: %p\n",
                        static_cast<const void*>(gemm_args.p_a_grid),
                        static_cast<const void*>(gemm_args.p_b_grid));
             }
+#endif
         }
         hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
         hip_check_error(hipEventSynchronize(stop));
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
index cbb9fadc6d..5de33c90fe 100644
--- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -55,6 +55,155 @@ struct DeviceGemmMultipleABD : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
+// GEMM:
+//   input : A0[M, K], B0[K, N],
+//   input : D0[M, N], D1[M, N], ...
+//   output : E[M, N]
+//   C = a_op(A) * b_op(B)
+//   E = cde_op(C, D0, D1, ...)
+// Assume:
+//   D0, D1, ... and E have the same layout
+template <typename AsLayout,
+          typename BsLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename AsDataType,
+          typename BsDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+struct DeviceGemmMultipleABDSplitK : public BaseOperator
+{
+    static constexpr index_t NumATensor = AsDataType::Size();
+    static constexpr index_t NumBTensor = BsDataType::Size();
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
+                        std::array<const void*, NumBTensor> p_bs,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        std::array<ck::index_t, NumATensor> StrideAs,
+                        std::array<ck::index_t, NumBTensor> StrideBs,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        ck::index_t StrideE,
+                        ck::index_t KBatch,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+/// @brief Wrapper for backward compatibility that allows using instances of
+///        DeviceGemmMultipleABDSplitK in contexts where DeviceGemmMultipleABD is expected.
+///
+/// @note The main place it is used is DeviceOperationInstanceFactory::GetInstances().
+///       The only API difference between DeviceGemmMultipleABD and DeviceGemmMultipleABDSplitK
+///       is that DeviceGemmMultipleABDSplitK::MakeArgumentPointer requires an additional
+///       parameter, KBatch, which is explicitly passed as 1 by this wrapper.
+template <typename AsLayout,
+          typename BsLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename AsDataType,
+          typename BsDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation>
+struct DeviceGemmMultipleABDSplitKWrapper : public DeviceGemmMultipleABD<AsLayout,
+                                                                         BsLayout,
+                                                                         DsLayout,
+                                                                         ELayout,
+                                                                         AsDataType,
+                                                                         BsDataType,
+                                                                         DsDataType,
+                                                                         EDataType,
+                                                                         AElementwiseOperation,
+                                                                         BElementwiseOperation,
+                                                                         CDEElementwiseOperation>
+{
+
+    using DeviceOp = DeviceGemmMultipleABDSplitK<AsLayout,
+                                                 BsLayout,
+                                                 DsLayout,
+                                                 ELayout,
+                                                 AsDataType,
+                                                 BsDataType,
+                                                 DsDataType,
+                                                 EDataType,
+                                                 AElementwiseOperation,
+                                                 BElementwiseOperation,
+                                                 CDEElementwiseOperation>;
+
+    static constexpr index_t NumATensor = AsDataType::Size();
+    static constexpr index_t NumBTensor = BsDataType::Size();
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+#ifndef __HIPCC_RTC__
+
+    explicit DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
+        : p_op_(std::move(p_op))
+    {
+    }
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return p_op_->IsSupportedArgument(p_arg);
+    }
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
+                        std::array<const void*, NumBTensor> p_bs,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        std::array<ck::index_t, NumATensor> StrideAs,
+                        std::array<ck::index_t, NumBTensor> StrideBs,
+                        std::array<ck::index_t, NumDTensor> StrideDs,
+                        ck::index_t StrideE,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) override
+    {
+        return p_op_->MakeArgumentPointer(p_as,
+                                          p_bs,
+                                          p_ds,
+                                          p_e,
+                                          M,
+                                          N,
+                                          K,
+                                          StrideAs,
+                                          StrideBs,
+                                          StrideDs,
+                                          StrideE,
+                                          1, // KBatch
+                                          a_element_op,
+                                          b_element_op,
+                                          cde_element_op);
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return p_op_->MakeInvokerPointer();
+    }
+
+    std::string GetTypeString() const override { return p_op_->GetTypeString(); }
+
+    private:
+    std::unique_ptr<DeviceOp> p_op_;
+
+#endif // __HIPCC_RTC__
+};
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
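A toy model of the adapter pattern the wrapper implements, reduced to the one behavioral difference (all names here are hypothetical stand-ins, not CK types): the SplitK interface carries an extra KBatch parameter, and the wrapper pins it to 1 so SplitK-capable instances can serve legacy callers unchanged:

    #include <memory>
    #include <utility>

    struct SplitKOp // stands in for DeviceGemmMultipleABDSplitK
    {
        virtual void Run(int KBatch) = 0;
        virtual ~SplitKOp()          = default;
    };

    struct LegacyOp // stands in for DeviceGemmMultipleABD
    {
        virtual void Run()  = 0;
        virtual ~LegacyOp() = default;
    };

    struct SplitKWrapper : LegacyOp
    {
        explicit SplitKWrapper(std::unique_ptr<SplitKOp> op) : op_(std::move(op)) {}
        void Run() override { op_->Run(/*KBatch=*/1); } // legacy callers get no K split
        std::unique_ptr<SplitKOp> op_;
    };
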
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
index c00078186f..e305dbfd9a 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp
@@ -64,9 +64,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 
     auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
+    // shift A matrices pointer for splitk
+    typename GridwiseGemm::AsGridPointer p_as_grid_shift;
+    static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
+        using ADataType_ =
+            remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType>>;
+        p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
+                             splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
+    });
+
+    // shift B matrices pointer for splitk
+    typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
+    static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
+        using BDataType_ =
+            remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType>>;
+        p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
+                             splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
+    });
+
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset,
-        karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset,
+        p_as_grid_shift,
+        p_bs_grid_shift,
         karg.p_ds_grid,
         karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
         p_shared,
@@ -278,8 +296,8 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
         Tuple<>, // DsLayout
         CLayout,
-        ADataType,
-        BDataType,
+        Tuple<ADataType>,
+        Tuple<BDataType>,
         AccDataType,
         CShuffleDataType,
         Tuple<>, // DsDataType
@@ -346,15 +364,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
-              p_a_grid_,
-              p_b_grid_,
+              std::array<const void*, 1>{p_a_grid_},
+              std::array<const void*, 1>{p_b_grid_},
               std::array<const void*, 0>{}, // p_ds_grid_
               p_c_grid_,
               M_,
               N_,
               K_,
-              StrideA_,
-              StrideB_,
+              std::array<index_t, 1>{StrideA_},
+              std::array<index_t, 1>{StrideB_},
               std::array<index_t, 0>{}, // StrideDs_
               StrideC_,
               k_batch_,
@@ -423,26 +441,33 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm<ALayout,
-                ck::utility::RotatingMemWrapper<Argument> rotating_mem(
-                    arg_,
-                    stream_config.rotating_count,
-                    arg_.Batch * size_a_buffer,
-                    arg_.Batch * size_b_buffer);
+                std::array<std::size_t, 1> size_as_buffers;
+                size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
+                                     sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch;
+
+                std::array<std::size_t, 1> size_bs_buffers;
+                size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
+                                     sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch;
+
+                ck::utility::RotatingMemWrapperMultiABD<Argument,
+                                                        Tuple<ADataType>,
+                                                        Tuple<BDataType>,
+                                                        Tuple<>>
+                    rotating_mem(arg_,
+                                 stream_config.rotating_count,
+                                 size_as_buffers,
+                                 size_bs_buffers,
+                                 std::array<std::size_t, 0>{});
                 rotating_mem.Print();
 
                 auto run_flush_cache = [&]() {
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp
new file mode 100644
index 0000000000..48914479bc
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp
@@ -0,0 +1,422 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iostream>
+#include <sstream>
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/flush_cache.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+/// @brief "Universal" GEMM operation with SplitK support and multiple A, B, and D tensors.
+///
+/// @par Overview
+///      This GEMM operation implements the following mathematical equation:
+///      E{M,N} = CDE_op(A_op(As{M,K}...) * B_op(Bs{K,N}...), Ds{M,N}...)
+///      Where As, Bs, Ds are input tensors and E is the output tensor. A_op and B_op are
+///      elementwise operations applied to each A/B tensor respectively. CDE_op is an
+///      elementwise operation applied to the C result and all D tensors.
+///      The "universal" gemm comes with multiple pipelines optimized for different usage
+///      scenarios. That's why it's called "universal". It's universal through its design
+///      and versatility.
+///
+/// @note This kernel implementation supports the SplitK algorithm. It can be configured
+///       to split the dot product accumulated over the K dimension into multiple working groups.
+///       The partial products of different workgroups are then reduced using the AtomicAdd
+///       operation.
+///
+/// @tparam AsLayout A tensors data layouts.
+/// @tparam BsLayout B tensors data layouts.
+/// @tparam DsLayout D tensors data layouts.
+/// @tparam ELayout E tensor data layout.
+/// @tparam AsDataType A tensors data types.
+/// @tparam BsDataType B tensors data types.
+/// @tparam DsDataType D tensors data types.
+/// @tparam EDataType E tensor data type.
+/// @tparam AccDataType The accumulation data type related to the hardware
+///         matrix-multiplication instruction.
+/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
+///         LDS memory during the "CShuffle" data layout optimization.
+/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
+/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
+/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after
+///         GEMM) and D input tensors.
+/// @tparam GemmSpec Determines the used "padding" version.
+/// @tparam BlockSize The number of threads within a workgroup.
+/// @tparam MPerBlock The input/output data tile size in the M dimension.
+/// @tparam NPerBlock The input/output data tile size in the N dimension.
+/// @tparam KPerBlock The input data tile size in the K dimension.
+/// @tparam AK1 The vector load size from global memory for the A tensor.
+/// @tparam BK1 The vector load size from global memory for the B tensor.
+/// @tparam MPerWmma M size of the Wave Matrix Multiply Accumulate (WMMA) instruction.
+/// @tparam NPerWmma N size of the Wave Matrix Multiply Accumulate (WMMA) instruction.
+/// @tparam MRepeat The number of iterations in the M dimension over the output tile per wavefront.
+/// @tparam NRepeat The number of iterations in the N dimension over the output tile per wavefront.
+/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
+///         data. Can be interpreted as the answer to the question: "How many threads can be
+///         arranged on each input data axis?"
+/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
+///         the input tensor dimensions. Can be interpreted as the answer to the question:
+///         "In which order to spread threads through tensor axes?"
+/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
+///         interpreted as the answer to the question: "Which dimension to read first? And which
+///         next?" etc.
+/// @tparam ABlockTransferSrcVectorDim The index of the axis on which we could do vectorized memory
+///         access - the one with contiguous memory.
+/// @tparam ABlockTransferSrcScalarPerVector The size of the vector access instruction - the number
+///         of elements accessed per thread per instruction.
+/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of the vectorized store into LDS memory.
+/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
+///         universal GEMM there's no need for padding.
+/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
+///         data. Can be interpreted as the answer to the question: "How many threads to
+///         arrange on each input data axis?"
+/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
+///         the input tensor dimensions. Can be interpreted as the answer to the question:
+///         "In which order to spread threads through tensor axes?"
+/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
+///         interpreted as the answer to the question: "Which dimension to read first? And which
+///         next?" etc.
+/// @tparam BBlockTransferSrcVectorDim The index of the axis on which we could do vectorized memory
+///         access - the one with contiguous memory.
+/// @tparam BBlockTransferSrcScalarPerVector The size of the vector access instruction - the number
+///         of elements accessed per thread per instruction.
+/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of the vectorized store into LDS memory.
+/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
+///         universal GEMM there's no need for padding.
+/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instruction
+///         results to process per wave per iteration of CShuffle in the M dimension.
+/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instruction
+///         results to process per wave per iteration of CShuffle in the N dimension.
+/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
+///         thread distribution used for storing data into the output tensor across output data
+///         layout dimensions.
+/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access.
+///         Used when loading data from D tensors and storing data to the output tensor.
+/// @tparam BlkGemmPipeSched The version of the blockwise-gemm pipeline scheduler (interwave or
+///         intrawave).
+/// @tparam BlkGemmPipelineVer The version of the blockwise-gemm pipeline.
+/// @tparam ComputeTypeA Data type used for the A input of hardware matrix-multiplication
+///         instructions.
+/// @tparam ComputeTypeB Data type used for the B input of hardware matrix-multiplication
+///         instructions.
+/// @tparam PermuteA Whether the A input tensor has a gridwise-gemm friendly data layout
+///         in global memory. Currently not supported!
+/// @tparam PermuteB Whether the B input tensor has a gridwise-gemm friendly data layout
+///         in global memory (pre-shuffled).
+template <typename AsLayout,
+          typename BsLayout,
+          typename DsLayout,
+          typename ELayout,
+          typename AsDataType,
+          typename BsDataType,
+          typename AccDataType,
+          typename CShuffleDataType,
+          typename DsDataType,
+          typename EDataType,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CDEElementwiseOperation,
+          GemmSpecialization GemmSpec,
+          index_t BlockSize,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t AK1,
+          index_t BK1,
+          index_t MPerWmma,
+          index_t NPerWmma,
+          index_t MRepeat,
+          index_t NRepeat,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
+          typename ABlockTransferThreadClusterArrangeOrder,
+          typename ABlockTransferSrcAccessOrder,
+          index_t ABlockTransferSrcVectorDim,
+          index_t ABlockTransferSrcScalarPerVector,
+          index_t ABlockTransferDstScalarPerVector_AK1,
+          index_t ABlockLdsExtraM,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
+          typename BBlockTransferThreadClusterArrangeOrder,
+          typename BBlockTransferSrcAccessOrder,
+          index_t BBlockTransferSrcVectorDim,
+          index_t BBlockTransferSrcScalarPerVector,
+          index_t BBlockTransferDstScalarPerVector_BK1,
+          index_t BBlockLdsExtraN,
+          index_t CShuffleMRepeatPerShuffle,
+          index_t CShuffleNRepeatPerShuffle,
+          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename CDEShuffleBlockTransferScalarPerVectors,
+          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
+          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
+          typename ComputeTypeA                       = EDataType,
+          typename ComputeTypeB                       = ComputeTypeA,
+          bool PermuteA                               = false,
+          bool PermuteB                               = false>
+struct DeviceGemmMultipleABD_Wmma_CShuffleV3
+    : public DeviceGemmMultipleABDSplitK<AsLayout,
+                                         BsLayout,
+                                         DsLayout,
+                                         ELayout,
+                                         AsDataType,
+                                         BsDataType,
+                                         DsDataType,
+                                         EDataType,
+                                         AElementwiseOperation,
+                                         BElementwiseOperation,
+                                         CDEElementwiseOperation>
+{
+    static constexpr index_t NumATensor = AsDataType::Size();
+    static constexpr index_t NumBTensor = BsDataType::Size();
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    // Note: multiple layouts are accepted here, but only the first A/B layout is
+    // used. This replicates the xdl functionality and should be extended.
+    using ALayout = remove_cvref_t<tuple_element_t<0, AsLayout>>;
+    using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
+
+    using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
+        ALayout,
+        BLayout,
+        DsLayout,
+        ELayout,
+        AsDataType,
+        BsDataType,
+        AccDataType,
+        CShuffleDataType,
+        DsDataType,
+        EDataType,
+        AElementwiseOperation,
+        BElementwiseOperation,
+        CDEElementwiseOperation,
+        GemmSpec,
+        BlockSize,
+        MPerBlock,
+        NPerBlock,
+        KPerBlock,
+        AK1,
+        BK1,
+        MPerWmma,
+        NPerWmma,
+        MRepeat,
+        NRepeat,
+        ABlockTransferThreadClusterLengths_AK0_M_AK1,
+        ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder,
+        ABlockTransferSrcVectorDim,
+        ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_AK1,
+        false,
+        ABlockLdsExtraM,
+        BBlockTransferThreadClusterLengths_BK0_N_BK1,
+        BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder,
+        BBlockTransferSrcVectorDim,
+        BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_BK1,
+        false,
+        BBlockLdsExtraN,
+        CShuffleMRepeatPerShuffle,
+        CShuffleNRepeatPerShuffle,
+        CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CDEShuffleBlockTransferScalarPerVectors,
+        BlkGemmPipeSched,
+        BlkGemmPipelineVer,
+        ComputeTypeA,
+        ComputeTypeB,
+        PermuteA,
+        PermuteB>;
+
+    using Argument = typename GridwiseGemm::Argument;
+
+    using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
+                                                               AsDataType,
+                                                               BsDataType,
+                                                               DsDataType,
+                                                               EDataType,
+                                                               MPerBlock,
+                                                               NPerBlock,
+                                                               KPerBlock,
+                                                               AK1,
+                                                               BK1>;
+
+    // Invoker
+    using Invoker = typename DeviceGemmCommon::Invoker;
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        return DeviceGemmCommon::IsSupportedArgument(arg);
+    }
+
+    // polymorphic
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto MakeArgument(std::array<const void*, NumATensor> p_as,
+                             std::array<const void*, NumBTensor> p_bs,
+                             std::array<const void*, NumDTensor> p_ds,
+                             void* p_e,
+                             index_t M,
+                             index_t N,
+                             index_t K,
+                             std::array<index_t, NumATensor> StrideAs,
+                             std::array<index_t, NumBTensor> StrideBs,
+                             std::array<index_t, NumDTensor> StrideDs,
+                             index_t StrideE,
+                             index_t KBatch,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CDEElementwiseOperation cde_element_op)
+    {
+        return Argument{p_as,
+                        p_bs,
+                        p_ds,
+                        static_cast<EDataType*>(p_e),
+                        M,
+                        N,
+                        K,
+                        StrideAs,
+                        StrideBs,
+                        StrideDs,
+                        StrideE,
+                        KBatch,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
+                        std::array<const void*, NumBTensor> p_bs,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        std::array<index_t, NumATensor> StrideAs,
+                        std::array<index_t, NumBTensor> StrideBs,
+                        std::array<index_t, NumDTensor> StrideDs,
+                        index_t StrideE,
+                        index_t KBatch,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) override
+    {
+        return std::make_unique<Argument>(p_as,
+                                          p_bs,
+                                          p_ds,
+                                          static_cast<EDataType*>(p_e),
+                                          M,
+                                          N,
+                                          K,
+                                          StrideAs,
+                                          StrideBs,
+                                          StrideDs,
+                                          StrideE,
+                                          KBatch,
+                                          a_element_op,
+                                          b_element_op,
+                                          cde_element_op);
+    }
+
+    // polymorphic
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>(Invoker{});
+    }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+
+        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
+            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
+            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};
+
+        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
+            {BlockGemmPipelineVersion::v1, "v1"},
+            {BlockGemmPipelineVersion::v2, "v2"},
+            {BlockGemmPipelineVersion::v3, "v3"},
+            {BlockGemmPipelineVersion::v4, "v4"},
+            {BlockGemmPipelineVersion::v5, "v5"}};
+
+        // clang-format off
+        str << "DeviceGemmMultipleABD_Wmma_CShuffleV3"
+            << "<"
+            << getGemmSpecializationString(GemmSpec) << ", ";
+        static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
+            using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
+
+            str << std::string(ALayout_::name)[0];
+        });
+        static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
+            using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
+
+            str << std::string(BLayout_::name)[0];
+        });
+        static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
+            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+            str << std::string(DLayout::name)[0];
+        });
+        str << std::string(ELayout::name)[0]
+            << ">"
+            << " BlkSize: "
+            << BlockSize << ", "
+            << "BlkTile: "
+            << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", "
+            << "WaveTile: "
+            << MPerWmma << "x" << NPerWmma << ", "
+            << "BlkGemmPipelineScheduler: "
+            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
+            << "BlkGemmPipelineVersion: "
+            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer];
+        // clang-format on
+
+        return str.str();
+    }
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp
@@ -178,8 +178,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
         ELayout,
-        ADataType,
-        BDataType,
+        Tuple<ADataType>,
+        Tuple<BDataType>,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -244,8 +244,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
     using DeviceGemmCommon =
         DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
-                                          ADataType,
-                                          BDataType,
+                                          Tuple<ADataType>,
+                                          Tuple<BDataType>,
                                           DsDataType,
                                           EDataType,
                                           MPerBlock,
@@ -291,15 +291,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
                              BElementwiseOperation b_element_op,
                              CDEElementwiseOperation cde_element_op)
     {
-        return Argument{static_cast<const ADataType*>(p_a),
-                        static_cast<const BDataType*>(p_b),
+        return Argument{std::array<const void*, 1>{p_a},
+                        std::array<const void*, 1>{p_b},
                         p_ds,
                         static_cast<EDataType*>(p_e),
                         M,
                         N,
                         K,
-                        StrideA,
-                        StrideB,
+                        std::array<index_t, 1>{StrideA},
+                        std::array<index_t, 1>{StrideB},
                         StrideDs,
                         StrideE,
                         KBatch,
@@ -328,15 +328,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3
                              BElementwiseOperation b_element_op,
                              CDEElementwiseOperation cde_element_op) override
    {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
+        return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
+                                          std::array<const void*, 1>{p_b},
                                           p_ds,
                                           static_cast<EDataType*>(p_e),
                                           M,
                                           N,
                                           K,
-                                          StrideA,
-                                          StrideB,
+                                          std::array<index_t, 1>{StrideA},
+                                          std::array<index_t, 1>{StrideB},
                                           StrideDs,
                                           StrideE,
                                           KBatch,
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp
index f1eb5e5d64..2ceeb39bac 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp
@@ -182,8 +182,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
         Tuple<>, // DsLayout
         CLayout,
-        ADataType,
-        BDataType,
+        Tuple<ADataType>,
+        Tuple<BDataType>,
         AccDataType,
         CShuffleDataType,
         Tuple<>, // DsDataType
@@ -233,8 +233,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
     using DeviceGemmCommon =
         DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
-                                          ADataType,
-                                          BDataType,
+                                          Tuple<ADataType>,
+                                          Tuple<BDataType>,
                                           Tuple<>,
                                           CDataType,
                                           MPerBlock,
@@ -283,15 +283,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
-        return Argument{p_a,
-                        p_b,
+        return Argument{std::array<const void*, 1>{p_a},
+                        std::array<const void*, 1>{p_b},
                         std::array<const void*, 0>{}, // p_ds_grid_
                         p_c,
                         M,
                         N,
                         K,
-                        StrideA,
-                        StrideB,
+                        std::array<index_t, 1>{StrideA},
+                        std::array<index_t, 1>{StrideB},
                         std::array<index_t, 0>{}, // StrideDs_
                         StrideC,
                         KBatch,
@@ -317,15 +317,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
+        return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
+                                          std::array<const void*, 1>{p_b},
                                           std::array<const void*, 0>{}, // p_ds_grid_
                                           static_cast<CDataType*>(p_c),
                                           M,
                                           N,
                                           K,
-                                          StrideA,
-                                          StrideB,
+                                          std::array<index_t, 1>{StrideA},
+                                          std::array<index_t, 1>{StrideB},
                                           std::array<index_t, 0>{}, // StrideDs_
                                           StrideC,
                                           KBatch,
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -91,8 +91,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, // DsLayout CLayout, - ADataType, - BDataType, + Tuple, + Tuple, + BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, // DsDataType @@ -144,8 +145,8 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, + Tuple, Tuple<>, CDataType, MPerBlock, @@ -195,15 +196,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{p_a}, + std::array{p_b}, std::array{}, // p_ds_grid_ p_c, M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, std::array{}, // StrideDs_ StrideC, StrideScaleB, @@ -233,15 +234,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale(static_cast(p_a), - static_cast(p_b), + return std::make_unique(std::array{p_a}, + std::array{p_b}, std::array{}, // p_ds_grid_ static_cast(p_c), M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, std::array{}, // StrideDs_ StrideC, StrideScaleB, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 72191632d8..4269d67d12 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -23,8 +23,8 @@ namespace tensor_operation { namespace device { template size_as_buffers; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + size_as_buffers[i] = a_grid_desc_ak0_m_ak1[i].GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + }); + + std::array size_bs_buffers; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + size_bs_buffers[i] = b_grid_desc_bk0_n_bk1[i].GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + }); const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); @@ -108,12 +117,13 @@ struct DeviceGemm_Wmma_CShuffleV3_Common ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, - stream_config.rotating_count, - size_a_buffer, - size_b_buffer, - size_ds_buffers); + ck::utility:: + RotatingMemWrapperMultiABD + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + size_ds_buffers); rotating_mem.Print(); auto run_flush_cache = [&]() { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index 3a06ea8451..df51a2aa27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -98,8 +98,8 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1, CLayout, - ADataType, - BDataType, + Tuple, + Tuple, GemmAccDataType, ReduceDataType, Tuple<>, @@ -147,15 +147,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1 p_a_grid_, + std::array p_b_grid_, const ::std::array p_ds_, CDataType* p_c_grid_, index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t 
StrideB_, + std::array StrideA_, + std::array StrideB_, const ::std::array stride_ds_, index_t StrideC_, index_t KBatch_, @@ -430,15 +430,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1{p_a}, + std::array{p_b}, p_ds, p_c, M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, stride_ds, StrideC, KBatch, @@ -472,15 +472,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1(static_cast(p_a), - static_cast(p_b), + return ::std::make_unique(std::array{p_a}, + std::array{p_b}, p_ds, static_cast(p_c), M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, DsStrides, StrideC, KSplit, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index bd2a8b04bc..d226510cf0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -11,6 +11,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -39,8 +40,8 @@ namespace ck { /// @tparam BLayout B tensor data layout. /// @tparam DsLayout D tensors data layouts. /// @tparam ELayout E tensor data layout. -/// @tparam ADataType A tensor data type. -/// @tparam BDataType B tensor data type. +/// @tparam AsDataType A tensors data types. +/// @tparam BsDataType B tensors data types. /// @tparam AccDataType The accumulation data type related to the hardware /// matrix-multiplication instruction. /// @tparam CShuffleDataType The data type used to store matrix-multiplication results into @@ -129,8 +130,8 @@ template StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t KBatch_) : M{M_}, N{N_}, K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, + StrideAs{StrideAs_}, + StrideBs{StrideBs_}, StrideDs{StrideDs_}, StrideE{StrideE_}, KBatch{KBatch_}, @@ -355,7 +362,15 @@ struct GridwiseGemm_wmma_cshuffle_v3 __host__ void Print() const { std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", "; + << "SAs: {"; + static_for<0, NumATensor, 1>{}([&](auto i) { + std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : ""); + }); + std::cout << "}, " << "SBs: {"; + static_for<0, NumBTensor, 1>{}([&](auto i) { + std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? 
", " : ""); + }); + std::cout << "}, "; if constexpr(NumDTensor > 0) { std::cout << "SDs: { "; @@ -373,8 +388,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t M; index_t N; index_t K; - index_t StrideA; - index_t StrideB; + std::array StrideAs; + std::array StrideBs; std::array StrideDs; index_t StrideE; index_t KBatch; @@ -391,15 +406,15 @@ struct GridwiseGemm_wmma_cshuffle_v3 // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, + __host__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, std::array p_ds_grid_, EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t StrideB_, + std::array StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t k_batch_, @@ -407,9 +422,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_}, - p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, + : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_}, + p_as_grid{}, + p_bs_grid{}, p_ds_grid{}, p_e_grid{p_e_grid_}, a_element_op{a_element_op_}, @@ -417,9 +432,27 @@ struct GridwiseGemm_wmma_cshuffle_v3 cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + + // A pointer + p_as_grid(i) = static_cast(p_as_grid_[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + + // B pointer + p_bs_grid(i) = static_cast(p_bs_grid_[i]); + }); + + // populate pointer, desc for Ds static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; + // D pointer p_ds_grid(i) = static_cast(p_ds_grid_[i]); }); } @@ -434,8 +467,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 return (Problem::KBatch > 1) && (!is_reduce); } - const ADataType* p_a_grid; - const BDataType* p_b_grid; + AsGridPointer p_as_grid; + BsGridPointer p_bs_grid; DsGridPointer p_ds_grid; EDataType* p_e_grid; @@ -452,29 +485,39 @@ struct GridwiseGemm_wmma_cshuffle_v3 __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { + // Note: in xdl implementation multiple AB supports one layout + // but multiple strides, so we create an array of offsets with + // the same values. + // It should be fixed later on. Once we will have a thread transfer + // more flexible. 
if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead / APackedSize; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; }); } else if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead * karg.StrideA; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; }); } if constexpr(is_same_v) { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; }); } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = k_id * karg.KRead / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = k_id * k0_offset / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); } } @@ -497,8 +540,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 } } - index_t a_k_split_offset; - index_t b_k_split_offset; + std::array a_k_split_offset; + std::array b_k_split_offset; index_t c_reduce_offset; }; @@ -514,8 +557,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, EDataType* p_e_grid, void* p_shared, @@ -524,10 +567,10 @@ struct GridwiseGemm_wmma_cshuffle_v3 BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N( @@ -562,20 +605,20 @@ struct GridwiseGemm_wmma_cshuffle_v3 const index_t num_k_block_per_scale = GetKBlockPerScale(); - Base::template Run(p_a_grid, - p_b_grid, + TailNum>(p_as_grid, + p_bs_grid, p_ds_grid, p_e_grid, p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -595,10 +638,26 @@ struct GridwiseGemm_wmma_cshuffle_v3 __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg) { + // shift A matrices pointer for splitk + AsGridPointer p_as_grid_splitk; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_splitk(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i]; + }); + + // shift B matrices pointer for splitk + BsGridPointer p_bs_grid_splitk; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + 
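// advance this B tensor's typed grid pointer by its per-tensor split-K offset +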
p_bs_grid_splitk(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i]; + }); + Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + p_as_grid_splitk, + p_bs_grid_splitk, + karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset, p_shared, karg, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index 29c5ae31cd..46de6b156a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -22,8 +22,9 @@ template { - using BScaleType = ck::half_t; - using Base = GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, - ADataType, - BDataType, + AsDataType, + BsDataType, AccDataType, CShuffleDataType, DsDataType, @@ -202,8 +201,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Base::CalculateMPadded; using Base::CalculateNBlock; using Base::CalculateNPadded; - using Base::MakeAGridDescriptor_AK0_M_AK1; - using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeAsGridDescriptor_AK0_M_AK1; + using Base::MakeBsGridDescriptor_BK0_N_BK1; using Base::MakeDEGridDescriptor_M_N; using Base::MakeDsGridDescriptor_M_N; using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; @@ -217,7 +216,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::NumATensor; + using Base::NumBTensor; using Base::NumDTensor; + using typename Base::AsGridPointer; + using typename Base::BsGridPointer; using typename Base::DsGridPointer; struct Problem @@ -225,8 +228,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __host__ Problem(index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t StrideB_, + std::array StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t StrideScaleB_, @@ -234,8 +237,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale : M{M_}, N{N_}, K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, + StrideAs{StrideAs_}, + StrideBs{StrideBs_}, StrideDs{StrideDs_}, StrideE{StrideE_}, StrideScaleB{StrideScaleB_}, @@ -254,7 +257,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __host__ void Print() const { std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", "; + << "SAs: {"; + static_for<0, NumATensor, 1>{}([&](auto i) { + std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : ""); + }); + std::cout << "}, " << "SBs: {"; + static_for<0, NumBTensor, 1>{}([&](auto i) { + std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? 
", " : ""); + }); + std::cout << "}, "; if constexpr(NumDTensor > 0) { std::cout << "SDs: { "; @@ -273,8 +284,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale index_t M; index_t N; index_t K; - index_t StrideA; - index_t StrideB; + std::array StrideAs; + std::array StrideBs; std::array StrideDs; index_t StrideE; index_t StrideScaleB; @@ -292,15 +303,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, + __host__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, std::array p_ds_grid_, EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t StrideB_, + std::array StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t StrideScaleB_, @@ -310,9 +321,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_}, - p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, + : Problem{M_, + N_, + K_, + StrideAs_, + StrideBs_, + StrideDs_, + StrideE_, + StrideScaleB_, + k_batch_}, + p_as_grid{}, + p_bs_grid{}, p_ds_grid{}, p_e_grid{p_e_grid_}, p_b_scale_grid{p_b_scale_grid_}, @@ -321,6 +340,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + + // A pointer + p_as_grid(i) = static_cast(p_as_grid_[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + + // B pointer + p_bs_grid(i) = static_cast(p_bs_grid_[i]); + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; @@ -338,8 +373,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale return (Problem::KBatch > 1) && (!is_reduce); } - const ADataType* p_a_grid; - const BDataType* p_b_grid; + AsGridPointer p_as_grid; + BsGridPointer p_bs_grid; DsGridPointer p_ds_grid; EDataType* p_e_grid; @@ -355,29 +390,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { + // Note: in xdl implementation multiple AB supports one layout + // but multiple strides, so we create an array of offsets with + // the same values. + // It should be fixed later on. Once we will have a thread transfer + // more flexible. 
if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead / APackedSize; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; }); } else if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead * karg.StrideA; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; }); } if constexpr(is_same_v) { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; }); } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = k_id * karg.KRead / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = k_id * k0_offset / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); } } @@ -410,8 +455,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale } } - index_t a_k_split_offset; - index_t b_k_split_offset; + std::array a_k_split_offset; + std::array b_k_split_offset; index_t scale_k_split_offset; // New member for scale matrix offset index_t c_reduce_offset; }; @@ -423,7 +468,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template + template __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, const BScaleType* p_b_scale_grid, index_t block_n_id) @@ -488,8 +533,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, EDataType* p_e_grid, const BScaleType* p_b_scale_grid, @@ -499,10 +544,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N( @@ -542,20 +587,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale const index_t num_k_block_per_scale = GetKBlockPerScale(); - Base::template Run(p_a_grid, - p_b_grid, + TailNum>(p_as_grid, + p_bs_grid, p_ds_grid, p_e_grid, p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -575,10 +620,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __device__ static void Run(void* p_shared, const 
SplitKBatchOffset& splitk_batch_offset, Argument& karg) { + // shift A matrices pointer for splitk + AsGridPointer p_as_grid_splitk; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_splitk(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i]; + }); + + // shift B matrices pointer for splitk + BsGridPointer p_bs_grid_splitk; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_splitk(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i]; + }); + Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + p_as_grid_splitk, + p_bs_grid_splitk, + karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset, karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, p_shared, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 59d3a6a4c5..dac0c9b3b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -61,8 +62,8 @@ template {}; static constexpr auto I7 = Number<7>{}; + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + + using LDSTypeA = + typename std::conditional<(NumATensor > 1), + ComputeTypeA, + remove_cvref_t>>::type; + using LDSTypeB = + typename std::conditional<(NumBTensor > 1), + ComputeTypeB, + remove_cvref_t>>::type; + static constexpr auto EShuffleBlockTransferScalarPerVector = CDEShuffleBlockTransferScalarPerVectors{}[I0]; @@ -136,14 +149,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using ThisThreadBlock = ThisThreadBlock; static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) + if constexpr(is_same_v, pk_i4_t>) return 2; else return 1; }(); static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) + if constexpr(is_same_v, pk_i4_t>) return 2; else return 1; @@ -230,6 +243,31 @@ struct GridwiseGemm_wmma_cshuffle_v3_base make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); } + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( index_t M, index_t MPad, index_t K, index_t KPad, index_t 
StrideA, index_t AK0) { @@ -314,6 +352,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + __host__ __device__ static auto + MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + const index_t MPad, + const index_t K, + const index_t KPad, + const std::array& StrideAs, + const index_t AK0) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0); + }, + Number{}); + } + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) { @@ -330,7 +383,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using GemmSpecialization = tensor_operation::device::GemmSpecialization; - static_assert(!(is_same_v, pk_i4_t> && + static_assert(!(is_same_v, pk_i4_t> && GemmSpec != GemmSpecialization::Default), "pk_i4_t does not support padding"); @@ -424,6 +477,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + const index_t KPad, + const index_t N, + const index_t NPad, + const std::array& StrideBs, + const index_t BK0) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0); + }, + Number{}); + } + template __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) { @@ -557,7 +625,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // in some cases. else if constexpr(is_same::value) { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize; constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( make_tuple( @@ -604,20 +672,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr auto KThreadRead = 64 / MPerWmma; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); constexpr auto KThreadReadPerm = (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) : KThreadRead; // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(LDSTypeA) > 128) ? 1 - : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + : ((128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))) > M0 ? M0 - : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + : 128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))); constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, @@ -694,7 +762,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base else if constexpr(is_same::value) { // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize; constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( make_tuple( @@ -738,20 +806,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr auto KThreadRead = 64 / NPerWmma; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); constexpr auto KThreadReadPerm = (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) : KThreadRead; // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + constexpr auto npair = (BK1Number * NPerWmma * sizeof(LDSTypeB) > 128) ? 1 - : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + : ((128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))) > N0 ? N0 - : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + : 128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))); constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, @@ -836,8 +904,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, - ADataType, - BDataType, + LDSTypeA, + LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, @@ -1120,11 +1188,24 @@ struct GridwiseGemm_wmma_cshuffle_v3_base c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat .GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), c_block_size * sizeof(CShuffleDataType)); } + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } + template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, DsGridPointer p_ds_grid, EDataType* p_e_grid, void* p_shared, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const AGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& @@ -1152,10 +1233,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const index_t& num_k_block_per_scale, BScaleStruct& b_scale_struct) { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + const auto ds_grid_buf = generate_tuple( [&](auto i) { return make_dynamic_buffer( @@ -1183,66 +1274,144 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr auto b_block_desc_bk0_n_bk1 = 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + // workaround because v7r2 is not as general as v4r1 + auto get_a_blockwise_transfer = [&]() { + if constexpr(NumATensor > 1) + { + const auto idx_as_block_begin = generate_tuple( + [&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + return ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + AGridDesc_AK0_M_K1, + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + remove_cvref_t>, + remove_cvref_t>, + decltype(as_grid_desc_ak0_m_ak1[I0]), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + as_grid_desc_ak0_m_ak1[I0], + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto a_blockwise_copy = get_a_blockwise_transfer(); // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + // workaround because v7r2 is not as general as v4r1 + auto get_b_blockwise_transfer = [&]() 
{ + if constexpr(NumBTensor > 1) + { + const auto idx_bs_block_begin = generate_tuple( + [&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + return ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + BGridDesc_BK0_N_K1, + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + remove_cvref_t>, + remove_cvref_t>, + decltype(bs_grid_desc_bk0_n_bk1[I0]), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + bs_grid_desc_bk0_n_bk1[I0], + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto b_blockwise_copy = get_b_blockwise_transfer(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1250,12 +1419,12 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - sizeof(ADataType) / - APackedSize), + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(LDSTypeA) / + APackedSize), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); @@ -1267,25 +1436,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - b_scale_struct, - num_k_block_main_loop, - num_k_block_per_scale); + blockwise_gemm_pipeline.template Run( + get_first_element_workaround(as_grid_desc_ak0_m_ak1), + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + 
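// get_first_element_workaround unwraps the one-element case: pipelines + // written for a single tensor take a bare descriptor/buffer, while the + // multi-tensor path passes the whole tuple through +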
get_first_element_workaround(as_grid_buf), + a_block_buf, + a_block_slice_copy_step, + get_first_element_workaround(bs_grid_desc_bk0_n_bk1), + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + get_first_element_workaround(bs_grid_buf), + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_struct, + num_k_block_main_loop, + num_k_block_per_scale); // shuffle C and write out { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp index 6e2950180d..3ebfdfa0d3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,11 +17,229 @@ namespace tensor_operation { namespace device { namespace instance { -using Multiply = ck::tensor_operation::element_wise::Multiply; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; #ifdef CK_ENABLE_INT8 + +#ifdef CK_USE_WMMA +// RRR +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// RCR +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// CRR +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances( 
+ std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// Multiply +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>>>& instances); + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances); + +#endif + +#ifdef CK_USE_XDL // RRR void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( std::vector, @@ -198,7 +416,7 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( std::vector, ck::Tuple, - ck::Tuple, + ck::Tuple, Row, ck::Tuple, ck::Tuple, @@ -233,10 +451,88 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( PassThrough, PassThrough, Multiply>>>& instances); - +#endif #endif // GEMM + Add + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> 
&& is_same_v) @@ -300,6 +597,27 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA #endif return op_ptrs; @@ -307,6 +625,81 @@ struct DeviceOperationInstanceFactory< }; // GEMM + Add +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(op_ptrs); + } + } +#endif + +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -372,11 +766,107 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA +#endif + return op_ptrs; } }; // GEMM + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -442,11 +933,106 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif +#endif return op_ptrs; } }; // GEMM +template +struct 
DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -511,13 +1098,95 @@ struct DeviceOperationInstanceFactory< } } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif +#endif return op_ptrs; } }; // Multiply // GEMM + Add + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -568,6 +1238,27 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif #endif return op_ptrs; @@ -575,6 +1266,67 @@ struct DeviceOperationInstanceFactory< }; // GEMM + Add +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -625,6 +1378,27 @@ struct 
DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif #endif return op_ptrs; @@ -632,6 +1406,68 @@ struct DeviceOperationInstanceFactory< }; // GEMM + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -682,6 +1519,27 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif #endif return op_ptrs; @@ -689,6 +1547,67 @@ struct DeviceOperationInstanceFactory< }; // GEMM +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -740,6 +1660,28 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif +#endif + return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt index 5af7322b1a..5ce585ad81 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt @@ -1,16 +1,26 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_MULTI_ABD_INSTANCES) list(APPEND GEMM_MULTI_ABD_INSTANCES - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - 
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
-    device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
-    device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
-    device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
-    device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
-    device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
-  )
+    device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+
+    device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+    device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+)
 add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES})
diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp
new file mode 100644
index 0000000000..8d4c45ae82
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp
@@ -0,0 +1,109 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
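[Editor's note: the new WMMA instance sources listed above follow the same registration pattern as the existing XDL ones: each translation unit appends a tuple of concrete device-op types to a type-erased instance vector owned by the factory. A minimal stand-alone sketch of that pattern; BaseOp/OpA/OpB and add_instances are hypothetical names, not CK code:

#include <memory>
#include <tuple>
#include <vector>

struct BaseOp { virtual ~BaseOp() = default; };
struct OpA : BaseOp {};
struct OpB : BaseOp {};

template <typename... Ops>
void add_instances(std::vector<std::unique_ptr<BaseOp>>& v, std::tuple<Ops...>)
{
    // default-construct one object per type in the tuple and type-erase it
    (v.emplace_back(std::make_unique<Ops>()), ...);
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> ops;
    add_instances(ops, std::tuple<OpA, OpB>{});
    return ops.size() == 2 ? 0 : 1;
}

End of editor's note.]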
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using D0Layout = Row; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using Add = ck::tensor_operation::element_wise::Add; + +using AElementOp = PassThrough; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | 
| | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, 
EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 32, 16, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 2>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 64, 16, 32, 256, 8, 8, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..eef450533b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
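[Editor's note: the common header above deliberately defines two instance lists: the ..._comp_instances tuple carries large tiles (256-thread blocks with up to 256x256x32 per block) for compute-bound shapes, while ..._mem_instances carries small tiles (16x16, 16x32 with deep K) for skinny, bandwidth-bound problems. A rough heuristic for why the split exists, purely illustrative; CK's actual dispatch simply tries every registered instance and keeps the supported/fastest one:

#include <cstdint>

enum class InstanceSet { Comp, Mem };

InstanceSet pick_instance_set(std::int64_t M, std::int64_t N, std::int64_t K)
{
    // Arithmetic intensity ~ 2*M*N*K flops over (M*K + K*N + M*N) elements moved.
    const double flops = 2.0 * M * N * K;
    const double elems = static_cast<double>(M * K + K * N + M * N);
    // The threshold is illustrative only; in practice profiling picks the winner.
    return (flops / elems > 64.0) ? InstanceSet::Comp : InstanceSet::Mem;
}

End of editor's note.]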
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
+
+#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
+                                                      ck::Tuple<Row, Row>,
+                                                      ck::Tuple<>,
+                                                      ELayout,
+                                                      AsDataType,
+                                                      ck::Tuple<I8, BF16>,
+                                                      ck::Tuple<>,
+                                                      EDataType,
+                                                      AElementOp,
+                                                      Multiply,
+                                                      PassThrough>>>& instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
+                                       ck::Tuple<Row, Row>,
+                                       ck::Tuple<>,
+                                       ck::Tuple<I8, BF16>,
+                                       ck::Tuple<>,
+                                       Multiply,
+                                       PassThrough,
+                                       GemmMNKPadding,
+                                       Interwave>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
+                                       ck::Tuple<Row, Row>,
+                                       ck::Tuple<>,
+                                       ck::Tuple<I8, BF16>,
+                                       ck::Tuple<>,
+                                       Multiply,
+                                       PassThrough,
+                                       GemmMNKPadding,
+                                       Interwave>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp
new file mode 100644
index 0000000000..0c2a34fbf8
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp
@@ -0,0 +1,85 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
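[Editor's note: the mk_nk_mn variants introduced next differ from mk_kn_mn only in B's storage layout. Per the layout naming used by the profiler later in this PR, KN means B is stored row-major as K x N, and NK means B is stored as N x K, i.e. column-major with respect to the K x N GEMM operand. In index terms, with illustrative helpers that are not CK code:

#include <cstddef>

inline float b_kn(const float* b, std::size_t k, std::size_t n, std::size_t strideN)
{
    return b[k * strideN + n]; // row-major K x N, leading stride N
}

inline float b_nk(const float* b, std::size_t k, std::size_t n, std::size_t strideK)
{
    return b[n * strideK + k]; // row-major N x K == column-major K x N, leading stride K
}

End of editor's note.]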
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using Add = ck::tensor_operation::element_wise::Add; + +using AElementOp = PassThrough; +using BElementOp = Multiply; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..30ab4135d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..56d30f9ad2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + AddFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + AddFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp new file mode 100644 index 0000000000..d4b9054a73 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
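[Editor's note: in the bias/gelu instance families above, BElementOp = Multiply fuses int8-weight dequantization into the B load: b0 holds the I8 weights, b1 holds the BF16 scales, and the kernel consumes their element-wise product as the effective B operand. A scalar sketch of that semantics, using plain float for clarity; this is not CK code:

#include <cstdint>
#include <vector>

std::vector<float> dequantize_b(const std::vector<std::int8_t>& b0, // int8 weights
                                const std::vector<float>& b1)       // scales, same flattened shape
{
    std::vector<float> b(b0.size());
    for(std::size_t i = 0; i < b0.size(); ++i)
        b[i] = static_cast<float>(b0[i]) * b1[i]; // BElementOp = Multiply
    return b;
}

End of editor's note.]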
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple, + AddFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple, + Add, + GemmMNKPadding, + Interwave>{}); +} + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple<>, + PassThrough, + GemmMNKPadding, + Interwave>{}); +} + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple<>, + FastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..cfeaad1a66 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
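[Editor's note: the four adder functions in the mk_nk_mn file above differ only in CDEElementOp, i.e. the epilogue applied to the accumulator c and the optional D tensors. Written out as scalar math; fast_gelu below is the common tanh approximation, and CK's exact polynomial may differ slightly:

#include <cmath>

float fast_gelu(float x)
{
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

float epilogue_pass_through(float c)              { return c; }                // PassThrough
float epilogue_add(float c, float d0)             { return c + d0; }           // Add (bias)
float epilogue_gelu(float c)                      { return fast_gelu(c); }     // FastGelu
float epilogue_add_gelu(float c, float d0)        { return fast_gelu(c + d0); } // AddFastGelu

End of editor's note.]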
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); + + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..fe36c30e75 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..69b0e6ff0b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyAdd>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..a779f27f62 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyAddFastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAddFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAddFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..dec51f72aa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
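[Editor's note: the multiply family registered in the surrounding files takes the other dequantization route: BElementOp stays PassThrough and the BF16 scale enters through D0 in the epilogue instead, via the MultiplyAdd-style CDE ops (their FastGelu variants additionally wrap the result in the gelu approximation shown earlier). Assuming the operand order e = c * d0 (+ d1), with d0 the scale and d1 the bias, which matches the D-tensor ordering in these files, the scalar semantics are roughly:

float multiply_epilogue(float c, float d0)                { return c * d0; }      // Multiply (as CDE op)
float multiply_add_epilogue(float c, float d0, float d1)  { return c * d0 + d1; } // MultiplyAdd

End of editor's note.]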
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp new file mode 100644 index 0000000000..a3c5c6a3ac --- /dev/null +++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +// this function is also defined in CK but because of the way we use it in +// profile_gemm_multi_impl, it requires the arguments to not be const +template +auto concat_tuple_of_refs(ck::Tuple& tx, ck::Tuple& ty) +{ + return ck::unpack2( + [&](auto&&... 
zs) { return ck::Tuple{ck::forward(zs)...}; }, + tx, + ty); +} + +template +bool profile_gemm_multi_abd_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + static constexpr index_t NumATensor = AsDataType::Size(); + auto as_m_k = generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + using ALayout = remove_cvref_t>; + + return Tensor(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + }, + Number{}); + + static constexpr index_t NumBTensor = BsDataType::Size(); + auto bs_k_n = generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + return Tensor(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + }, + Number{}); + + static constexpr index_t NumDTensor = DsDataType::Size(); + auto ds_m_n = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + using DLayout = remove_cvref_t>; + + return Tensor(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + }, + Number{}); + + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + static_for<0, NumATensor, 1>{}( + [&](auto i) { std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; }); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; }); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; }); + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + + as_m_k(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + + bs_k_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + ds_m_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }); + + break; + default: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + + as_m_k(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + + bs_k_n(i).GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + ds_m_n(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + }); + } + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + using PassThrough = 
ck::tensor_operation::element_wise::PassThrough; + Tensor c_m_n({M, N}); + + using AComputeType = + typename std::conditional<(NumATensor > 1), + EDataType, + remove_cvref_t>>::type; + + auto get_a_matrix = [&]() -> auto { + // in case of pass through we avoid allocating a new + // tensor and copying values + if constexpr(is_same_v) + { + return as_m_k(Number<0>{}); + } + else + { + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(a_element_op, data_refs); + } + } + return a_m_k; + } + }; + + using BComputeType = + typename std::conditional<(NumBTensor > 1), + EDataType, + remove_cvref_t>>::type; + + auto get_b_matrix = [&]() -> auto { + // in case of pass through we avoid allocating a new + // tensor and copying values + if constexpr(is_same_v) + { + return bs_k_n(Number<0>{}); + } + else + { + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(b_element_op, data_refs); + } + } + return b_k_n; + } + }; + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + get_a_matrix(), get_b_matrix(), c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + // compulsory + auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n)); + // optional (if multiple Ds) + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return ds_m_n(Number{})(m, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(cde_element_op, data_refs); + } + } + } + + std::array as_device_buf; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_device_buf[i] = new DeviceMem(sizeof(ADataType) * as_m_k(i).mDesc.GetElementSpaceSize()); + }); + + std::array bs_device_buf; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_device_buf[i] = new DeviceMem(sizeof(BDataType) * bs_k_n(i).mDesc.GetElementSpaceSize()); + }); + + std::array ds_device_buf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_device_buf[i] = new DeviceMem(sizeof(DDataType) * ds_m_n(i).mDesc.GetElementSpaceSize()); + }); + + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + static_for<0, NumATensor, 1>{}( + [&](auto i) { as_device_buf[i]->ToDevice(as_m_k(i).mData.data()); }); + + static_for<0, NumBTensor, 1>{}( + [&](auto i) { bs_device_buf[i]->ToDevice(bs_k_n(i).mData.data()); }); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_device_buf[i]->ToDevice(ds_m_n(i).mData.data()); }); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + std::array as_pointer; + std::array as_stride; + 
static_for<0, NumATensor, 1>{}([&](auto i) { + as_pointer[i] = as_device_buf[i]->GetDeviceBuffer(); + as_stride[i] = StrideA; + }); + + std::array bs_pointer; + std::array bs_stride; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_pointer[i] = bs_device_buf[i]->GetDeviceBuffer(); + bs_stride[i] = StrideB; + }); + std::array ds_pointer; + std::array ds_stride; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_pointer[i] = ds_device_buf[i]->GetDeviceBuffer(); + ds_stride[i] = StrideD; + }); + + auto argument_ptr = op_ptr->MakeArgumentPointer(as_pointer, + bs_pointer, + ds_pointer, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + as_stride, + bs_stride, + ds_stride, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t sizeADataType = 0; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + sizeADataType = std::max(sizeADataType, sizeof(ADataType)); + }); + std::size_t sizeBDataType = 0; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + sizeBDataType = std::max(sizeBDataType, sizeof(BDataType)); + }); + + std::size_t num_btype = + sizeADataType * M * K + sizeBDataType * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + static_for<0, NumATensor, 1>{}([&](auto i) { delete as_device_buf[i]; }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { delete bs_device_buf[i]; }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { delete ds_device_buf[i]; }); + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 31f684fe75..c31ede2c73 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -96,6 +96,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) @@ -234,6 +235,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES 
device_grouped_conv2d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance) diff --git a/profiler/src/profile_gemm_multi_abd.cpp b/profiler/src/profile_gemm_multi_abd.cpp new file mode 100644 index 0000000000..157bcbc977 --- /dev/null +++ b/profiler/src/profile_gemm_multi_abd.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_multi_abd_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + BF16_I8_BF16_BF16, // 0 +}; + +enum struct GemmElementOp +{ + PASS_THROUGH, // 0 + MULTIPLY, // 1 + ADD, // 2 + FASTGELU, // 3 + ADD_FASTGELU, // 4 + MULTIPLY_ADD, // 5 + MULTIPLY_FASTGELU, // 6 + MULTIPLY_ADD_FASTGELU, // 7 +}; + +#define OP_NAME "gemm_multi_abd" +#define OP_DESC "GEMM_Multiple_ABD" + +int profile_gemm_multi_abd(int argc, char* argv[]) +{ + if(argc != 18) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: bf16@int8/bf16->bf16;)\n"); + printf("arg3: matrix layout (0: E[m, n] = A[m, k] * B[k, n];\n"); + printf(" 1: E[m, n] = A[m, k] * B[n, k];\n"); + printf(" 2: E[m, n] = A[k, m] * B[k, n];\n"); + printf(" 3: E[m, n] = A[k, m] * B[n, k])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8: number of As (1)\n"); + printf("arg9: number of Bs (1/2)\n"); + printf("arg10: number of Ds (0/1/2)\n"); + printf("arg11 to 17: M, N, K, StrideA, StrideB, StrideE, StrideD\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int num_as = std::stoi(argv[8]); + const int num_bs = std::stoi(argv[9]); + const int num_ds = std::stoi(argv[10]); + + const int M = std::stoi(argv[11]); + const int N = std::stoi(argv[12]); + const int K = std::stoi(argv[13]); + + const int StrideA = std::stoi(argv[14]); + const int StrideB = std::stoi(argv[15]); + const int StrideE = std::stoi(argv[16]); + const int StrideD = std::stoi(argv[17]); + + using F32 = float; + using BF16 = ck::bhalf_t; + using I8 = int8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Multiply = ck::tensor_operation::element_wise::Multiply; + using FastGelu = ck::tensor_operation::element_wise::FastGelu; + using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + + auto profile = [&](auto b_layout, auto b_element_op, auto cde_element_op, auto num_d_tensor) { + using ADataType = BF16; + using B0DataType = I8; + using B1DataType 
= BF16; + using DDataType = BF16; + using EDataType = BF16; + + using ALayout = Row; + using BLayout = decltype(b_layout); + using DLayout = Row; + using ELayout = Row; + + using AElementOp = PassThrough; + using BElementOp = decltype(b_element_op); + using CDEElementOp = decltype(cde_element_op); + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + constexpr auto NumberDTensor = decltype(num_d_tensor){}; + + // Only num_d_tensor == 0 and 1 are supported + using DsDataType = typename std:: + conditional<(NumberDTensor == 0), ck::Tuple<>, ck::Tuple>::type; + using DsLayout = + typename std::conditional<(NumberDTensor == 0), ck::Tuple<>, ck::Tuple>::type; + + bool pass = ck::profiler::profile_gemm_multi_abd_impl, + ck::Tuple, + F32, + DsDataType, + EDataType, + ck::Tuple, + ck::Tuple, + DsLayout, + ELayout, + AElementOp, + BElementOp, + CDEElementOp>( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD < 0) ? DefaultStrideD : StrideD, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + // num_as == 1 is only supported + if(data_type != GemmDataType::BF16_I8_BF16_BF16 || num_as != 1) + { + std::cout << "The provided input parameters are not supported" << std::endl; + return 1; + } + + // Supported configurations + if(layout == GemmMatrixLayout::MK_KN_MN && num_bs == 2 && num_ds == 1) + { + return profile(Row{}, Multiply{}, AddFastGelu{}, ck::Number<1>{}); + } + else if(layout == GemmMatrixLayout::MK_KN_MN && num_bs == 2 && num_ds == 0) + { + return profile(Row{}, Multiply{}, FastGelu{}, ck::Number<0>{}); + } + else if(layout == GemmMatrixLayout::MK_NK_MN && num_bs == 2 && num_ds == 1) + { + return profile(Col{}, Multiply{}, AddFastGelu{}, ck::Number<1>{}); + } + else if(layout == GemmMatrixLayout::MK_NK_MN && num_bs == 2 && num_ds == 0) + { + return profile(Col{}, Multiply{}, FastGelu{}, ck::Number<0>{}); + } + + std::cout << "The provided input parameters are not supported" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_multi_abd); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index cedac568db..df3a03cca8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -243,6 +243,7 @@ add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_add) add_subdirectory(gemm_layernorm) +add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) add_subdirectory(gemm_b_scale) diff --git a/test/gemm_multi_abd/CMakeLists.txt b/test/gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..d700414b05 --- /dev/null +++ b/test/gemm_multi_abd/CMakeLists.txt @@ -0,0 +1,9 @@ +add_gtest_executable(test_gemm_multi_abd_wmma test_gemm_multi_abd_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_multi_abd_wmma PRIVATE utility device_gemm_multi_abd_instance) +endif() + +add_gtest_executable(test_gemm_multi_abd_xdl test_gemm_multi_abd_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_multi_abd_xdl PRIVATE utility device_gemm_multi_abd_instance) +endif() diff --git a/test/gemm_multi_abd/test_gemm_common.hpp b/test/gemm_multi_abd/test_gemm_common.hpp new file mode 100644 index 0000000000..030fbcac77 --- /dev/null +++ b/test/gemm_multi_abd/test_gemm_common.hpp @@ 
-0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmCommon : public ::testing::Test +{ + protected: + using AsLayout = std::tuple_element_t<0, Tuple>; + using BsLayout = std::tuple_element_t<1, Tuple>; + using DsLayout = std::tuple_element_t<2, Tuple>; + using ELayout = Row; + using AsDataType = std::tuple_element_t<3, Tuple>; + using BsDataType = std::tuple_element_t<4, Tuple>; + using DsDataType = std::tuple_element_t<5, Tuple>; + using EDataType = std::tuple_element_t<6, Tuple>; + using AElementOp = std::tuple_element_t<7, Tuple>; + using BElementOp = std::tuple_element_t<8, Tuple>; + using CDEElementOp = std::tuple_element_t<9, Tuple>; + + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {512, 1024, 2048}, {1024, 512, 32}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + // Assuming same layout for all A matrices (same applies for Bs and Ds) + int StrideA = ck::is_same_v>, Row> ? K : M; + int StrideB = ck::is_same_v>, Row> ? N : K; + // In case no D matrices are provided, set stride to 0 + int StrideD = 0; + if constexpr(DsDataType::Size() > 0) + { + StrideD = ck::is_same_v>, Row> ? N : M; + } + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & ck::profiler::profile_gemm_multi_abd_impl( + 1, 2, false, false, M, N, K, StrideA, StrideB, StrideD, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp new file mode 100644 index 0000000000..42584ecc02 --- /dev/null +++ b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
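[Editor's note: Run() above derives leading strides from the layouts before calling the profiler implementation. The rule it encodes, shown as a tiny illustrative helper that is not CK code: a row-major rows x cols matrix has leading stride cols, a column-major one has stride rows; hence StrideA = K for row-major A (M x K), StrideB = N for row-major B (K x N), and so on.

#include <cstddef>

std::size_t default_stride(std::size_t rows, std::size_t cols, bool row_major)
{
    return row_major ? cols : rows;
}

End of editor's note.]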
+ +#include + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multi_abd_impl.hpp" +#include "test_gemm_common.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using BF16 = ck::bhalf_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using Add = ck::tensor_operation::element_wise::Add; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; + +using KernelTypesABD = ::testing::Types, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; + +TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); +TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); } + +} // namespace test +} // namespace ck diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp new file mode 100644 index 0000000000..42584ecc02 --- /dev/null +++ b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
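[Editor's note: both test files drive a single typed fixture over the tuple list above, so every layout/element-op combination gets its own gtest case. A stand-alone toy of the same gtest typed-test mechanism, link against gtest_main; names here are hypothetical:

#include <gtest/gtest.h>

template <typename T>
class SmallTypedTest : public ::testing::Test
{
};

using MyTypes = ::testing::Types<int, float>;
TYPED_TEST_SUITE(SmallTypedTest, MyTypes);

TYPED_TEST(SmallTypedTest, DefaultIsZero)
{
    // instantiated once per type in MyTypes
    EXPECT_EQ(TypeParam{}, TypeParam{0});
}

End of editor's note.]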
+ +#include + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multi_abd_impl.hpp" +#include "test_gemm_common.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using BF16 = ck::bhalf_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using Add = ck::tensor_operation::element_wise::Add; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; + +using KernelTypesABD = ::testing::Types, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; + +TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); +TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); } + +} // namespace test +} // namespace ck
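[Editor's note: with the REGISTER_PROFILER_OPERATION call above, the new operation should be reachable from the profiler binary (conventionally named ckProfiler). Matching the argument parsing in profile_gemm_multi_abd.cpp, a representative invocation would be:

    ckProfiler gemm_multi_abd 0 0 1 2 0 1 1 2 1 1024 1024 1024 -1 -1 -1 -1

i.e. bf16@int8->bf16 data (arg2 = 0), MK_KN_MN layout (arg3 = 0), verification on, decimal initialization, no tensor logging, timed kernels, one A tensor, two B tensors, one D tensor, M = N = K = 1024, and negative strides so the computed per-layout defaults are used. The sizes and flags here are illustrative, not taken from the PR.]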