From 37bbf7a46a3a1ecb6e5465d059c993a5e68a63b8 Mon Sep 17 00:00:00 2001
From: zjing14
Date: Mon, 15 Apr 2024 21:09:45 -0500
Subject: [PATCH] Added Multi_ABD support into Gemm and GroupedGemmFixedNK (#978)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* added an example grouped_gemm_multi_abd
* fixed ci
* add setElementwiseOp
* changed API
* clean code: add multiA into example
* fixed v7r2 copy
* add transpose
* clean
* fixed vector_load check
* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
  Co-authored-by: Bartłomiej Kocot
* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
  Co-authored-by: Bartłomiej Kocot
* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
  Co-authored-by: Bartłomiej Kocot
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
  Co-authored-by: Bartłomiej Kocot
* add reduce
* testing
* add example_b16_i8
* refactor example
* clean
* add m padding
* disable reduce for kbatch = 1
* separate reduce device op
* add reduce op
* add guard for workspace_size
* add instances
* format
* fixed
* add client example
* add a colmajor
* add instances
* Update cmake-ck-dev.sh
* Update profile_gemm_splitk.cpp
* Update gridwise_gemm_xdlops_v2r4r2.hpp
* format
* Update profile_gemm_splitk.cpp
* fixed
* fixed
* adjust test
* adjust precision loss
* adjust test
* fixed
* add bf16_i8 scale bias
* fixed scale
* fixed scale elementwise_op
* revert contraction deviceop changes
* fixed
* Add AddFastGelu
* Revert "Merge branch 'jizhan/gemm_splitk_reduce' into grouped_gemm_multi_abd_fixed_nk_example"
  This reverts commit 3b5d001efd74335b38dcb7d8c8877580b49d23a4, reversing changes
  made to 943199a99191661c5597c51ca8371a90bf57837e.
* add Scales into elementwise
* add gemm_multi_abd client example
* add client examples
* add rcr and crr
* add grouped gemm client example
* add grouped gemm client example
* add instance for rcr crr
* format
* fixed
* fixed cmake
* fixed
* fixed client_example
* format
* fixed contraction isSupport
* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
  Co-authored-by: Bartłomiej Kocot
* Update device_reduce_threadwise.hpp
* clean
* Fixes
* Fix example

---------

Co-authored-by: Jing Zhang
Co-authored-by: Bartłomiej Kocot

[ROCm/composable_kernel commit: 12865fbf285ce4d9890845b8d14a9418525bf439]
---
 .../30_gemm_multi_abd/CMakeLists.txt          |  13 +
 .../gemm_bias_fastgelu_xdl_bf16_i8.cpp        | 262 ++++++
 .../gemm_bias_xdl_bf16_i8.cpp                 | 262 ++++++
 .../30_gemm_multi_abd/gemm_xdl_bf16_i8.cpp    | 257 ++++++
 .../gemm_xdl_gelu_bf16_i8.cpp                 | 261 ++++++
 .../31_grouped_gemm_multi_abd/CMakeLists.txt  |   7 +
 ...grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp | 286 ++++++
 .../grouped_gemm_fastgelu_xdl_bf16_i8.cpp     | 282 ++++++
 .../59_grouped_gemm_multi_ABD/CMakeLists.txt  |   7 +
 ...mm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp | 401 +++++++++
 ..._gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp | 397 ++++++++
 example/60_gemm_multi_ABD/CMakeLists.txt      |   1 +
 .../gemm_multi_ABD_xdl_bf16_i8.cpp            | 270 ++++++
 .../gemm_multi_ABD_xdl_fp16.cpp               |  12 +-
 .../contraction_multi_ABD_xdl_fp16.cpp        |   4 +-
 .../device/device_grouped_gemm_multi_abd.hpp  |  98 ++
 ...device_grouped_gemm_multi_abd_fixed_nk.hpp |  81 ++
 ..._contraction_multiple_abd_xdl_cshuffle.hpp |  12 +-
 .../device_gemm_multiple_abd_xdl_cshuffle.hpp | 106 +--
 ...ce_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 851 ++++++++++++++++++
 .../device/impl/device_reduce_threadwise.hpp  |   1 -
 .../element/binary_element_wise_operation.hpp |  24 +-
 .../combined_element_wise_operation.hpp       |   9 +
 .../element/unary_element_wise_operation.hpp  |  37 +-
 ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp |  13 +-
 .../threadwise_tensor_slice_transfer.hpp      |  33 +-
 .../threadwise_tensor_slice_transfer_util.hpp |  67 ++
 .../threadwise_tensor_slice_transfer_v3r1.hpp |  37 +-
 .../threadwise_tensor_slice_transfer_v7r2.hpp | 247 ++++-
 .../gpu/gemm_multi_abd.hpp                    | 468 ++++++++++
 .../gpu/grouped_gemm_multi_abd_fixed_nk.hpp   | 470 ++++++++++
 .../gpu/gemm_multi_abd/CMakeLists.txt         |  10 +
 ...multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp | 101 +++
 ...multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp | 101 +++
 ...multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp | 101 +++
 ...gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp | 115 +++
 ...gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp | 115 +++
 ...gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp | 115 +++
 .../CMakeLists.txt                            |  10 +
 ..._fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp |  89 ++
 ..._fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp |  89 ++
 ..._fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp |  89 ++
 ...as_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp | 111 +++
 ...as_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp | 111 +++
 ...as_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp | 111 +++
 45 files changed, 6345 insertions(+), 199 deletions(-)
 create mode 100644 client_example/30_gemm_multi_abd/CMakeLists.txt
 create mode 100644 client_example/30_gemm_multi_abd/gemm_bias_fastgelu_xdl_bf16_i8.cpp
 create mode 100644 client_example/30_gemm_multi_abd/gemm_bias_xdl_bf16_i8.cpp
 create mode 100644 client_example/30_gemm_multi_abd/gemm_xdl_bf16_i8.cpp
 create mode 100644 client_example/30_gemm_multi_abd/gemm_xdl_gelu_bf16_i8.cpp
 create mode 100644 client_example/31_grouped_gemm_multi_abd/CMakeLists.txt
 create mode 100644 client_example/31_grouped_gemm_multi_abd/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp
 create mode 100644 client_example/31_grouped_gemm_multi_abd/grouped_gemm_fastgelu_xdl_bf16_i8.cpp
 create mode 100644 example/59_grouped_gemm_multi_ABD/CMakeLists.txt
 create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp
 create mode 100644 example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp
 create mode 100644 example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bf16_i8.cpp
 create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp
 create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
 create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp
 create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp
 create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
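Note: in the bf16/int8 client examples below, B is supplied as two tensors, int8 weights (B0) plus a bf16 scale tensor (B1, broadcast via a zero stride), and BElementOp = Scales dequantizes them on the fly before the XDL MACs. A host-side sketch of that per-element combine, with illustrative names and types (the actual device operator lives in CK's element-wise headers):

    #include <cstdint>

    // Reference-only sketch: combine one int8 weight with its scale,
    // as the Scales element-wise op does inside the kernel.
    inline float dequant_b(int8_t b0, float b1_scale)
    {
        return b1_scale * static_cast<float>(b0);
    }
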
diff --git a/client_example/30_gemm_multi_abd/CMakeLists.txt b/client_example/30_gemm_multi_abd/CMakeLists.txt
new file mode 100644
index 0000000000..4d85c68400
--- /dev/null
+++ b/client_example/30_gemm_multi_abd/CMakeLists.txt
@@ -0,0 +1,13 @@
+if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
+    add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp)
+    target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
+
+    add_executable(client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp)
+    target_link_libraries(client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
+
+    add_executable(client_gemm_gelu_bf16_i8_bf16 gemm_xdl_gelu_bf16_i8.cpp)
+    target_link_libraries(client_gemm_gelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
+
+    add_executable(client_gemm_bf16_i8_bf16 gemm_xdl_bf16_i8.cpp)
+    target_link_libraries(client_gemm_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
+endif()
diff --git a/client_example/30_gemm_multi_abd/gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/30_gemm_multi_abd/gemm_bias_fastgelu_xdl_bf16_i8.cpp
new file mode 100644
index 0000000000..486cdf74dd
--- /dev/null
+++ b/client_example/30_gemm_multi_abd/gemm_bias_fastgelu_xdl_bf16_i8.cpp
@@ -0,0 +1,262 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <array>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = BF16;
+using D0DataType       = BF16;
+using DsDataType       = ck::Tuple<D0DataType>;
+using EDataType        = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Col;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<D0Layout>;
+using ELayout  = Row;
+
+using Scales      = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Scales;
+using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+// clang-format on
+int main(int argc, char* argv[])
+{
+    // GEMM shape
+    ck::index_t M = 64;
+    ck::index_t N = 1024;
+    ck::index_t K = 512;
+
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = K;
+    ck::index_t StrideD = N;
+    ck::index_t StrideE = N;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 8)
+    {
+        M = std::stoi(argv[1]);
+        N = std::stoi(argv[2]);
+        K = std::stoi(argv[3]);
+
+        StrideA = std::stoi(argv[4]);
+        StrideB = std::stoi(argv[5]);
+        StrideD = std::stoi(argv[6]);
+        StrideE = std::stoi(argv[7]);
+    }
+    else
+    {
+        printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD, StrideE\n");
+        exit(0);
+    }
+
+    auto f_matrix_space_size =
+        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+            using Layout = decltype(layout);
+
+            if constexpr(std::is_same<Layout, Row>::value)
+            {
+                return (nRow - 1) * stride + nCol;
+            }
+            else
+            {
+                return (nCol - 1) * stride + nRow;
+            }
+        };
+
+    SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
+                                  f_matrix_space_size(M, K, StrideA, A0Layout{}));
+    SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
+                                  f_matrix_space_size(K, N, StrideB, B0Layout{}));
+    SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
+    SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
+                                  f_matrix_space_size(M, N, StrideD, ELayout{}));
+    SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 1;
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
+        BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
+        AElementOp, BElementOp, CDEElementOp>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{StrideD},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+
+            std::size_t num_btype =
+                sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+    if(found)
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{StrideD},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
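Note: every stride array above and below passes 0 for B1, e.g. std::array<ck::index_t, NumBTensor>{StrideB, 0}. A zero leading stride makes each step along that dimension alias the same elements, so the scale tensor is broadcast against B0 without materializing a full K x N copy. Illustrative offset arithmetic (an assumed helper, not a CK API):

    #include <cstddef>

    // Offset of element (major, minor) in a strided 2-D view;
    // stride == 0 collapses the major dimension, i.e. a broadcast.
    inline std::size_t strided_offset(std::size_t major, std::size_t minor, std::size_t stride)
    {
        return major * stride + minor;
    }
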
diff --git a/client_example/30_gemm_multi_abd/gemm_bias_xdl_bf16_i8.cpp b/client_example/30_gemm_multi_abd/gemm_bias_xdl_bf16_i8.cpp
new file mode 100644
index 0000000000..8f47cb143e
--- /dev/null
+++ b/client_example/30_gemm_multi_abd/gemm_bias_xdl_bf16_i8.cpp
@@ -0,0 +1,262 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <array>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = BF16;
+using D0DataType       = BF16;
+using DsDataType       = ck::Tuple<D0DataType>;
+using EDataType        = BF16;
+
+using A0Layout = Col;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<D0Layout>;
+using ELayout  = Row;
+
+using Scales      = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Add         = ck::tensor_operation::element_wise::Add;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Scales;
+using CDEElementOp = Add;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+// clang-format on
+int main(int argc, char* argv[])
+{
+    // GEMM shape
+    ck::index_t M = 64;
+    ck::index_t N = 1024;
+    ck::index_t K = 512;
+
+    ck::index_t StrideA = M;
+    ck::index_t StrideB = N;
+    ck::index_t StrideD = N;
+    ck::index_t StrideE = N;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 8)
+    {
+        M = std::stoi(argv[1]);
+        N = std::stoi(argv[2]);
+        K = std::stoi(argv[3]);
+
+        StrideA = std::stoi(argv[4]);
+        StrideB = std::stoi(argv[5]);
+        StrideD = std::stoi(argv[6]);
+        StrideE = std::stoi(argv[7]);
+    }
+    else
+    {
+        printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD, StrideE\n");
+        exit(0);
+    }
+
+    auto f_matrix_space_size =
+        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+            using Layout = decltype(layout);
+
+            if constexpr(std::is_same<Layout, Row>::value)
+            {
+                return (nRow - 1) * stride + nCol;
+            }
+            else
+            {
+                return (nCol - 1) * stride + nRow;
+            }
+        };
+
+    SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
+                                  f_matrix_space_size(M, K, StrideA, A0Layout{}));
+    SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
+                                  f_matrix_space_size(K, N, StrideB, B0Layout{}));
+    SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
+    SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
+                                  f_matrix_space_size(M, N, StrideD, ELayout{}));
+    SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 1;
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
+        BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
+        AElementOp, BElementOp, CDEElementOp>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{StrideD},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+
+            std::size_t num_btype =
+                sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+    if(found)
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{StrideD},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
diff --git a/client_example/30_gemm_multi_abd/gemm_xdl_bf16_i8.cpp b/client_example/30_gemm_multi_abd/gemm_xdl_bf16_i8.cpp
new file mode 100644
index 0000000000..f2e5d6187d
--- /dev/null
+++ b/client_example/30_gemm_multi_abd/gemm_xdl_bf16_i8.cpp
@@ -0,0 +1,257 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <array>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = BF16;
+using DsDataType       = ck::Tuple<>;
+using EDataType        = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Col;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<>;
+using ELayout  = Row;
+
+using Scales      = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using Add         = ck::tensor_operation::element_wise::Add;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Scales;
+using CDEElementOp = PassThrough;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+// clang-format on
+int main(int argc, char* argv[])
+{
+    // GEMM shape
+    ck::index_t M = 64;
+    ck::index_t N = 1024;
+    ck::index_t K = 512;
+
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = N;
+    ck::index_t StrideE = N;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 7)
+    {
+        M = std::stoi(argv[1]);
+        N = std::stoi(argv[2]);
+        K = std::stoi(argv[3]);
+
+        StrideA = std::stoi(argv[4]);
+        StrideB = std::stoi(argv[5]);
+        StrideE = std::stoi(argv[6]);
+    }
+    else
+    {
+        printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideE\n");
+        exit(0);
+    }
+
+    auto f_matrix_space_size =
+        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+            using Layout = decltype(layout);
+
+            if constexpr(std::is_same<Layout, Row>::value)
+            {
+                return (nRow - 1) * stride + nCol;
+            }
+            else
+            {
+                return (nCol - 1) * stride + nRow;
+            }
+        };
+
+    SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
+                                  f_matrix_space_size(M, K, StrideA, A0Layout{}));
+    SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
+                                  f_matrix_space_size(K, N, StrideB, B0Layout{}));
+    SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
+    SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 0;
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
+        BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
+        AElementOp, BElementOp, CDEElementOp>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+
+            std::size_t num_btype =
+                sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+    if(found)
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
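Note: the next example swaps the epilogue to FastGelu (GELU with no bias add). A host-side reference of the tanh-style fast-GELU approximation it targets; CK's device code in unary_element_wise_operation.hpp may use an equivalent exp-based form, so treat this as a numerical sketch, not the kernel's exact arithmetic:

    #include <cmath>

    // Fast-GELU reference: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inline float fast_gelu_ref(float x)
    {
        const float c = 0.7978845608f; // sqrt(2 / pi)
        return 0.5f * x * (1.0f + std::tanh(c * (x + 0.044715f * x * x * x)));
    }
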
diff --git a/client_example/30_gemm_multi_abd/gemm_xdl_gelu_bf16_i8.cpp b/client_example/30_gemm_multi_abd/gemm_xdl_gelu_bf16_i8.cpp
new file mode 100644
index 0000000000..4e2ada1295
--- /dev/null
+++ b/client_example/30_gemm_multi_abd/gemm_xdl_gelu_bf16_i8.cpp
@@ -0,0 +1,261 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <array>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = BF16;
+using DsDataType       = ck::Tuple<>;
+using EDataType        = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Col;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<>;
+using ELayout  = Row;
+
+using Scales      = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using FastGelu    = ck::tensor_operation::element_wise::FastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Scales;
+using CDEElementOp = FastGelu;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+// clang-format on
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    // GEMM shape
+    ck::index_t M = 64;
+    ck::index_t N = 1024;
+    ck::index_t K = 512;
+
+    ck::index_t StrideA = K;
+    ck::index_t StrideB = N;
+    ck::index_t StrideE = N;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 7)
+    {
+        M = std::stoi(argv[1]);
+        N = std::stoi(argv[2]);
+        K = std::stoi(argv[3]);
+
+        StrideA = std::stoi(argv[4]);
+        StrideB = std::stoi(argv[5]);
+        StrideE = std::stoi(argv[6]);
+    }
+    else
+    {
+        printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideE\n");
+        exit(0);
+    }
+
+    auto f_matrix_space_size =
+        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+            using Layout = decltype(layout);
+
+            if constexpr(std::is_same<Layout, Row>::value)
+            {
+                return (nRow - 1) * stride + nCol;
+            }
+            else
+            {
+                return (nCol - 1) * stride + nRow;
+            }
+        };
+
+    SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
+                                  f_matrix_space_size(M, K, StrideA, A0Layout{}));
+    SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
+                                  f_matrix_space_size(K, N, StrideB, B0Layout{}));
+    SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
+    SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 0;
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
+        BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
+        AElementOp, BElementOp, CDEElementOp>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+
+            std::size_t num_btype =
+                sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+    if(found)
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
+                                                b1_device_buf.GetDeviceBuffer()},
+            std::array<const void*, NumDTensor>{},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            std::array<ck::index_t, NumATensor>{StrideA},
+            std::array<ck::index_t, NumBTensor>{StrideB, 0},
+            std::array<ck::index_t, NumDTensor>{},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
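Note: the grouped fixed-NK client examples that follow drive the new DeviceGroupedGemmMultiABDFixedNK interface in four steps, sketched here in outline with the names used below:

    // 1) argument = op->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
    //    (the pointer vectors may be empty; shapes come from gemm_descs)
    // 2) copy the per-group GroupedGemmMultiABDKernelArgument array to device
    //    memory sized by op->GetDeviceKernelArgSize(argument.get())
    // 3) op->SetDeviceKernelArgs(argument.get(), dev_kernel_args);
    //    op->SetElementwiseOps(argument.get(), a_op, b_op, cde_op);
    // 4) invoker->Run(argument.get(), StreamConfig{...});
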
diff --git a/client_example/31_grouped_gemm_multi_abd/CMakeLists.txt b/client_example/31_grouped_gemm_multi_abd/CMakeLists.txt
new file mode 100644
index 0000000000..c4303d622f
--- /dev/null
+++ b/client_example/31_grouped_gemm_multi_abd/CMakeLists.txt
@@ -0,0 +1,7 @@
+if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
+    add_executable(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp)
+    target_link_libraries(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
+
+    add_executable(client_grouped_gemm_fastgelu_bf16_i8_bf16 grouped_gemm_fastgelu_xdl_bf16_i8.cpp)
+    target_link_libraries(client_grouped_gemm_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
+endif()
diff --git a/client_example/31_grouped_gemm_multi_abd/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_multi_abd/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp
new file mode 100644
index 0000000000..f5ed713281
--- /dev/null
+++ b/client_example/31_grouped_gemm_multi_abd/grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp
@@ -0,0 +1,286 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = BF16;
+using D0DataType       = BF16;
+using DsDataType       = ck::Tuple<D0DataType>;
+using EDataType        = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Col;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<D0Layout>;
+using ELayout  = Row;
+
+using Scales      = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Scales;
+using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+struct ProblemSize final
+{
+    std::vector<ck::index_t> Ms;
+    std::vector<ck::index_t> Ns;
+    std::vector<ck::index_t> Ks;
+
+    std::vector<ck::index_t> stride_As;
+    std::vector<ck::index_t> stride_Bs;
+    std::vector<ck::index_t> stride_Cs;
+
+    ck::index_t group_count;
+};
+
+struct ExecutionConfig final
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+    int k_batch          = 1;
+};
+
+bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
+{
+    auto group_count = problem_size.group_count;
+
+    // GEMM shape
+    std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
+
+    gemm_descs.reserve(group_count);
+
+    int sum_of_m = 0;
+
+    using DeviceMemPtr = std::unique_ptr<SimpleDeviceMem>;
+
+    std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
+        d0_tensors_device, c_tensors_device;
+
+    a0_tensors_device.reserve(group_count);
+    b0_tensors_device.reserve(group_count);
+    b1_tensors_device.reserve(group_count);
+    d0_tensors_device.reserve(group_count);
+    c_tensors_device.reserve(group_count);
+
+    std::size_t flop = 0, num_btype = 0;
+
+    for(int i = 0; i < group_count; i++)
+    {
+        sum_of_m += problem_size.Ms[i];
+    }
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 1;
+
+    using GroupedGemmKernelArgument = ck::tensor_operation::device::
+        GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
+
+    std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
+    grouped_gemm_kernel_args_.reserve(group_count);
+
+    for(int i = 0; i < group_count; i++)
+    {
+        a0_tensors_device.emplace_back(
+            std::make_unique<SimpleDeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
+
+        b0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
+            sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
+
+        b1_tensors_device.emplace_back(
+            std::make_unique<SimpleDeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
+
+        d0_tensors_device.emplace_back(
+            std::make_unique<SimpleDeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
+
+        c_tensors_device.emplace_back(
+            std::make_unique<SimpleDeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
+
+        gemm_descs.push_back(
+            {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1});
+
+        grouped_gemm_kernel_args_.push_back(
+            {std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer()},
+             std::array<const void*, NumBTensor>{b0_tensors_device[i]->GetDeviceBuffer(),
+                                                 b1_tensors_device[i]->GetDeviceBuffer()},
+             std::array<const void*, NumDTensor>{d0_tensors_device[i]->GetDeviceBuffer()},
+             c_tensors_device[i]->GetDeviceBuffer(),
+             problem_size.Ms[i],
+             problem_size.Ns[i],
+             problem_size.Ks[i],
+             std::array<ck::index_t, NumATensor>{problem_size.stride_As[i]},
+             std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i], 0},
+             std::array<ck::index_t, NumDTensor>{0},
+             problem_size.stride_Cs[i]});
+    }
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
+        BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
+        AElementOp, BElementOp, CDEElementOp>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        std::vector<std::array<const void*, NumATensor>> p_As = {};
+        std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
+        std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
+        std::vector<void*> p_Cs                               = {};
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+
+            SimpleDeviceMem gemm_kernel_args_dev(
+                op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
+            hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
                                      grouped_gemm_kernel_args_.data(),
                                      op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
                                      hipMemcpyHostToDevice));
+
+            op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer());
+
+            op_ptr->SetElementwiseOps(
+                argument_ptr.get(), a_element_op, b_element_op, cde_element_op);
+
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0];
+
+            std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] +
+                                    sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] +
+                                    sizeof(EDataType) * sum_of_m * problem_size.Ns[0];
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    return true;
+}
+
+int main(int argc, char* argv[])
+{
+    ProblemSize problem_size;
+    ExecutionConfig config;
+
+    problem_size.group_count = 16;
+
+    for(int i = 0; i < problem_size.group_count; i++)
+    {
+        problem_size.Ms.push_back(32 + rand() % 32);
+        problem_size.Ns.push_back(1024);
+        problem_size.Ks.push_back(512);
+
+        problem_size.stride_As.push_back(problem_size.Ks[i]);
+        problem_size.stride_Bs.push_back(problem_size.Ns[i]);
+        problem_size.stride_Cs.push_back(problem_size.Ns[i]);
+    }
+
+    return !run_grouped_gemm(problem_size, config);
+}
diff --git a/client_example/31_grouped_gemm_multi_abd/grouped_gemm_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_multi_abd/grouped_gemm_fastgelu_xdl_bf16_i8.cpp
new file mode 100644
index 0000000000..c2110fbd51
--- /dev/null
+++ b/client_example/31_grouped_gemm_multi_abd/grouped_gemm_fastgelu_xdl_bf16_i8.cpp
@@ -0,0 +1,282 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using F32  = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType       = BF16;
+using AsDataType       = ck::Tuple<A0DataType>;
+using B0DataType       = I8;
+using B1DataType       = BF16;
+using BsDataType       = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType      = F32;
+using CShuffleDataType = BF16;
+using D0DataType       = BF16;
+using DsDataType       = ck::Tuple<>;
+using EDataType        = BF16;
+
+using A0Layout = Col;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+using DsLayout = ck::Tuple<>;
+using ELayout  = Row;
+
+using Scales      = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using FastGelu    = ck::tensor_operation::element_wise::FastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = Scales;
+using CDEElementOp = FastGelu;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+struct ProblemSize final
+{
+    std::vector<ck::index_t> Ms;
+    std::vector<ck::index_t> Ns;
+    std::vector<ck::index_t> Ks;
+
+    std::vector<ck::index_t> stride_As;
+    std::vector<ck::index_t> stride_Bs;
+    std::vector<ck::index_t> stride_Cs;
+
+    ck::index_t group_count;
+};
+
+struct ExecutionConfig final
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+    int k_batch          = 1;
+};
+
+bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
+{
+    auto group_count = problem_size.group_count;
+
+    // GEMM shape
+    std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
+
+    gemm_descs.reserve(group_count);
+
+    int sum_of_m = 0;
+
+    using DeviceMemPtr = std::unique_ptr<SimpleDeviceMem>;
+
+    std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
+        c_tensors_device;
+
+    a0_tensors_device.reserve(group_count);
+    b0_tensors_device.reserve(group_count);
+    b1_tensors_device.reserve(group_count);
+    c_tensors_device.reserve(group_count);
+
+    std::size_t flop = 0, num_btype = 0;
+
+    for(int i = 0; i < group_count; i++)
+    {
+        sum_of_m += problem_size.Ms[i];
+    }
+
+    constexpr ck::index_t NumATensor = 1;
+    constexpr ck::index_t NumBTensor = 2;
+    constexpr ck::index_t NumDTensor = 0;
+
+    using GroupedGemmKernelArgument = ck::tensor_operation::device::
+        GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
+
+    std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
+    grouped_gemm_kernel_args_.reserve(group_count);
+
+    for(int i = 0; i < group_count; i++)
+    {
+        a0_tensors_device.emplace_back(
+            std::make_unique<SimpleDeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
+
+        b0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
+            sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
+
+        b1_tensors_device.emplace_back(
+            std::make_unique<SimpleDeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
+
+        c_tensors_device.emplace_back(
+            std::make_unique<SimpleDeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
+
+        gemm_descs.push_back(
+            {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {}, 1});
+
+        grouped_gemm_kernel_args_.push_back(
+            {std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer()},
+             std::array<const void*, NumBTensor>{b0_tensors_device[i]->GetDeviceBuffer(),
+                                                 b1_tensors_device[i]->GetDeviceBuffer()},
+             std::array<const void*, NumDTensor>{},
+             c_tensors_device[i]->GetDeviceBuffer(),
+             problem_size.Ms[i],
+             problem_size.Ns[i],
+             problem_size.Ks[i],
+             std::array<ck::index_t, NumATensor>{problem_size.stride_As[i]},
+             std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i], 0},
+             std::array<ck::index_t, NumDTensor>{},
+             problem_size.stride_Cs[i]});
+    }
+
+    auto a_element_op   = AElementOp{};
+    auto b_element_op   = BElementOp{};
+    auto cde_element_op = CDEElementOp{};
+
+    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
+        BsLayout, DsLayout, ELayout, AsDataType, BsDataType, DsDataType, EDataType,
+        AElementOp, BElementOp, CDEElementOp>;
+
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        std::vector<std::array<const void*, NumATensor>> p_As = {};
+        std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
+        std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
+        std::vector<void*> p_Cs                               = {};
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+
+            SimpleDeviceMem gemm_kernel_args_dev(
+                op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
+            hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
                                      grouped_gemm_kernel_args_.data(),
                                      op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
                                      hipMemcpyHostToDevice));
+
+            op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer());
+
+            op_ptr->SetElementwiseOps(
+                argument_ptr.get(), a_element_op, b_element_op, cde_element_op);
+
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0];
+
+            std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] +
+                                    sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] +
+                                    sizeof(EDataType) * sum_of_m * problem_size.Ns[0];
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    return true;
+}
+
+int main(int argc, char* argv[])
+{
+    ProblemSize problem_size;
+    ExecutionConfig config;
+
+    problem_size.group_count = 16;
+
+    for(int i = 0; i < problem_size.group_count; i++)
+    {
+        problem_size.Ms.push_back(32 + rand() % 32);
+        problem_size.Ns.push_back(1024);
+        problem_size.Ks.push_back(512);
+
+        problem_size.stride_As.push_back(problem_size.Ks[i]);
+        problem_size.stride_Bs.push_back(problem_size.Ns[i]);
+        problem_size.stride_Cs.push_back(problem_size.Ns[i]);
+    }
+
+    return !run_grouped_gemm(problem_size, config);
+}
diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt
new file mode 100644
index 0000000000..78f6832895
--- /dev/null
+++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_custom_target(example_grouped_gemm_xdl_multi_abd)
+
+add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
+
+add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp)
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)
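Note: "fixed NK" means every group shares N and K while M may differ per group, which is why the example below sizes the pooled A and E buffers with sum_of_m (the sum of all group Ms) and only fixes each group's M when the kernel arguments are uploaded. The sizing reduces to a simple accumulation (sketch, mirroring the loop in the example):

    #include <numeric>
    #include <vector>

    // total rows across all groups, used to size the pooled A/E buffers
    inline int sum_of_m_of(const std::vector<int>& Ms)
    {
        return std::accumulate(Ms.begin(), Ms.end(), 0);
    }
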
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Scales = ck::tensor_operation::element_wise::Scales; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Scales; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 
16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a0_tensors; + std::vector> b_tensors; + std::vector> b0_tensors; + std::vector> b1_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a0_tensors.reserve(group_count); + b_tensors.reserve(group_count); + b0_tensors.reserve(group_count); + b1_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, b0_tensors_device, b1_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + b0_tensors_device.reserve(group_count); + b1_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + b1_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{}))); + + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() + + sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * 
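+        // NOTE: num_btype below is a logical-byte estimate: GetElementSize()
+        // is the product of the tensor lengths, so the stride-0 broadcast
+        // operands (b1 scales, d0 bias) are charged at their full K x N /
+        // M x N footprint even though only a single vector is stored.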
c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-5, 5}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + b0_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + b1_tensors_device.emplace_back( + std::make_unique(sizeof(B1DataType) * problem_size.Ns[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(), + b0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + + b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(), + b1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B1DataType)); + + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back( + {sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer()}, + std::array{b0_tensors_device[i]->GetDeviceBuffer(), + b1_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i], 0}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n)); + } + } + + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + c_host_tensors[i], + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000..809c1a956c --- /dev/null +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -0,0 +1,397 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; +using Scale = ck::tensor_operation::element_wise::Scale; +using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp; + +using A0DataType = F16; +using A1DataType = F32; +using AsDataType = ck::Tuple; +using B0DataType = F16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using A0Layout = Row; +using A1Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = AddScale; +using BElementOp = PassThrough; + +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, 
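+// NOTE: this variant fuses two A tensors: AsDataType = Tuple<F16, F32> and
+// AElementOp = AddScale = BinaryWithUnaryCombinedOp<Add, Scale, Scale>, so
+// (as the name suggests) the kernel consumes a = Add(Scale(a0), Scale(a1))
+// while loading the A tile; the factors are supplied below as
+// AElementOp{Add{}, Scale{scale}, Scale{scale}}.  Rough scalar model
+// (a sketch only; `add_scale` is a hypothetical helper):
+//   float add_scale(ck::half_t a0, float a1, float s0, float s1) {
+//       return s0 * ck::type_convert<float>(a0) + s1 * a1;
+//   }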
GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ck::half_t>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a0_tensors; + std::vector> a1_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> e_host_tensors; + std::vector> e_device_tensors; + + a0_tensors.reserve(group_count); + a1_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + e_host_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a0_tensors_device, a1_tensors_device, b_tensors_device, + d0_tensors_device, c_tensors_device; + + a0_tensors_device.reserve(group_count); + a1_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{}))); + a1_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << e_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() + + sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() + + sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + 
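+            // init_method 1 draws small integers so the f32 reference and
+            // the device result should agree bit-exactly; init_method 2 uses
+            // real-valued data and leans on the check_err tolerance instead.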
a1_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + constexpr ck::index_t NumATensor = 2; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 1; + + using GroupedGemmKernelArgument = ck::tensor_operation::device:: + GroupedGemmMultiABDKernelArgument; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a0_tensors_device.emplace_back( + std::make_unique(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i])); + + a1_tensors_device.emplace_back( + std::make_unique(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(), + a0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A0DataType)); + + a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(), + a1_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(A1DataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(B0DataType)); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + {1, 1}, + {problem_size.stride_Bs[i]}, + {0}, + 1}); + + grouped_gemm_kernel_args_.push_back( + {std::array{a0_tensors_device[i]->GetDeviceBuffer(), + a1_tensors_device[i]->GetDeviceBuffer()}, + std::array{b_tensors_device[i]->GetDeviceBuffer()}, + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + std::array{problem_size.stride_As[i], + problem_size.stride_As[i]}, + std::array{problem_size.stride_Bs[i]}, + std::array{0}, + problem_size.stride_Cs[i]}); + } + + constexpr float scale = 1.f; + auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector> p_As = {}; + std::vector> p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int k = 0; k < problem_size.Ks[i]; ++k) + { + a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k)); + } + } + + c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), + e_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i], + b_tensors[i], + e_host_tensors[i], + PassThrough{}, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(32 + rand() % 32); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(512); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index d3974897fe..91e1f8009d 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1 +1,2 @@ add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) 
+add_example_executable(example_gemm_multi_ABD_xdl_bf16_i8 gemm_multi_ABD_xdl_bf16_i8.cpp) \ No newline at end of file diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bf16_i8.cpp new file mode 100644 index 0000000000..7693956a75 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bf16_i8.cpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Scales = ck::tensor_operation::element_wise::Scales; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Scales; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>; + +// clang-format on +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 64; + ck::index_t N = 1024; + ck::index_t K = 512; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + 
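+    // NOTE: device buffers are sized with GetElementSpaceSize(), i.e. the
+    // strided storage footprint rather than the logical element count.  For
+    // the stride-0 b1_k_n descriptor this collapses the broadcast dimension:
+    // with the default 512 x 1024 problem, HostTensorDescriptor({K, N}, {1, 0})
+    // has an element space size of (K-1)*1 + (N-1)*0 + 1 = 512 elements, a
+    // single vector that the descriptor broadcasts across the other dimension.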
+ a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, 0}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 
0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 13ba639814..93034a8b70 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -37,7 +37,7 @@ using DDataType = F16; using EDataType = F16; using ALayout = Row; -using BLayout = Col; +using BLayout = Row; using DLayout = Row; using ELayout = Row; @@ -141,9 +141,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 1, 2, 8, - 8, 1, 1, 1, @@ -161,10 +161,10 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; float alpha = 1.0f; float beta = 1.0f; diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 4317484336..8b88e2482d 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -102,7 +102,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultiple S<1, 0, 2>, S<1, 0, 2>, 2, - 8, + 1, 8, 1, S<4, 64, 1>, @@ -131,7 +131,7 @@ int main(int argc, char* argv[]) std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; // A1[M1, K1] -> A1[M0, M1, K0, K1] std::vector a1_ms_ks_lengths{30, 128, 32, 64}; - std::vector a1_ms_ks_strides{0, 64, 0, 1}; + std::vector a1_ms_ks_strides{0, 64, 1, 0}; // B[N0, N1, K0, K1] std::vector b_ns_ks_lengths{32, 64, 32, 64}; std::vector b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp new file mode 100644 index 0000000000..59483cb89f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct GemmMultiABDDesc +{ + ck::index_t M_, N_, K_; + + std::vector stride_As_; + std::vector stride_Bs_; + std::vector stride_Ds_; + + ck::index_t stride_C_; +}; + +/* + * \brief Grouped Gemm Multi ABD + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam AsLayout A layouts (tuple). + * \tparam BsLayout B layouts (tuple). + * \tparam DsLayout Ds layouts (tuple). + * \tparam ELayout Output layout. + * \tparam AsDataType A data types (tuple). + * \tparam BsDataType B data types (tuple). + * \tparam DsDataType D data types (tuple). + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation C elementwise operation. 
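+ *
+ * For example, with two B tensors and one D tensor this evaluates
+ *   E = cde_op(a_op(A0) * b_op(B0, B1), D0),
+ * which is how the bf16 x int8 examples fuse dequantization
+ * (b_op = Scales) with bias and activation (cde_op = AddFastGelu).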
+ */ +template +struct DeviceGroupedGemmMultiABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor"); + static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor"); + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor"); + + /* + * \brief Make argument pointer for grouped gemm multi abd. + * + * \param p_as A pointers to the A. + * \param p_bs A pointers to the B. + * \param p_ds A pointers to the Ds. + * \param p_e A pointers to the E. + * \param gemm_desc Gemm descriptors for each group. + * \param a_element_op A elementwise operation object. + * \param b_element_op B elementwise operation object. + * \param cde_element_op CDE elementwise operation object. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(std::vector>& p_as, + std::vector>& p_bs, + std::vector>& p_ds, + std::vector& p_e, + std::vector& gemm_desc, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp new file mode 100644 index 0000000000..05c3e796a8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_grouped_gemm_multi_abd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GroupedGemmMultiABDKernelArgument +{ + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + void* p_e_grid; + + index_t M; + index_t N; + index_t K; + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + index_t StrideE; +}; + +/* + * \brief Grouped Gemm Multi ABD Fixed NK + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam AsLayout A layouts (tuple). + * \tparam BsLayout B layouts (tuple). + * \tparam DsLayout Ds layouts (tuple). + * \tparam ELayout Output layout. + * \tparam AsDataType A data types (tuple). + * \tparam BsDataType B data types (tuple). + * \tparam DsDataType D data types (tuple). + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation C elementwise operation. 
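+ *
+ * "Fixed NK" means every group shares the same N and K (the host-side
+ * descriptors must also carry the same M, conventionally sum_of_m), while
+ * the real per-group M is delivered through a device-resident array of
+ * GroupedGemmMultiABDKernelArgument via SetDeviceKernelArgs, so group sizes
+ * can change without rebuilding the argument.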
+ */ +template +struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD +{ + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; + virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 62f5454b4e..33e03a85e2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -663,7 +663,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i]; const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; - if(!(valid_a_vector_size && valid_a_access_dim)) + if(!((valid_a_vector_size && valid_a_access_dim) || + ABlockTransferSrcScalarPerVector == 1)) { valid_as_access = false; } @@ -682,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i]; const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; - if(!(valid_b_vector_size && valid_b_access_dim)) + if(!((valid_b_vector_size && valid_b_access_dim) || + BBlockTransferSrcScalarPerVector == 1)) { valid_bs_access = false; } @@ -698,7 +700,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector read of Ds is always on N dimension. const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; - if(!(valid_d_vector_size && valid_d_access_dim)) + if(!((valid_d_vector_size && valid_d_access_dim) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) { valid_ds_access = false; } @@ -712,7 +715,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; // Vector write of E is always on N dimension. 
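+            // A transfer with ScalarPerVector == 1 is legal on any
+            // dimension, so the contiguity requirement is only enforced when
+            // the instance actually vectorizes the access.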
const bool valid_e_access_dim = arg.e_nz_consecutive_; - if(!(valid_e_vector_size && valid_e_access_dim)) + if(!((valid_e_vector_size && valid_e_access_dim) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp index 21914d466d..1af2be91f8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -169,78 +169,6 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD{}; static constexpr auto I3 = Number<3>{}; -#if 0 - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - - static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideAs) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideAs, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideAs)); - } - }(); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - } - - static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideBs) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideBs)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideBs, I1)); - } - }(); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - } - - template - static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) - { - const auto e_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(StrideE, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(I1, StrideE)); - } - }(); - - return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); - } - - static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, - const std::array& NRaws, - const std::array& DsStride) - { - return generate_tuple( - [&](auto i) { - using DLayout = remove_cvref_t>; - - return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); - }, - Number{}); - } -#endif using ComputeDataType = EDataType; // GridwiseGemm @@ -384,7 +312,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD( - KRaw, NRaw, StrideBs[i]); + NRaw, KRaw, StrideBs[i]); }); // populate pointer, desc for Ds @@ -424,15 +352,6 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD{}( - //[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); - // std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; - } - // private: // pointers typename GridwiseGemm::AsGridPointer p_as_grid_; @@ -578,7 +497,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD{}([&](auto i) { using DLayout = remove_cvref_t>; @@ -618,21 +542,21 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD) { if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) { - return false; + all_valid = false; } } else + { + all_valid = 
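+                // this layout combination has no vectorized E path; support
+                // is rejected through the accumulated all_valid flag below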
false; + } + + if(!all_valid) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp new file mode 100644 index 0000000000..bf8788a3b2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -0,0 +1,851 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t KBatch = 1; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M * N * K == 0) + return; + + const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; + const auto StrideBs = gemm_desc_ptr[group_id].StrideBs; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_bs_grid[i]); + }); + + static_for<0, 
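+    // Ds grid pointers are re-materialized from the device-side argument
+    // block the same way as As/Bs above; static_for unrolls over the tuple
+    // so each pointer keeps its element type from DsDataType.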
NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + GridwiseGemm:: + template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + block_2_etile_map); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = grid_size_grp; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK + : public DeviceGroupedGemmMultiABDFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t NumGemmKPrefetchStage = 1; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple( + // idx_bot[Number<0>{}], + idx_bot[Number<1>{}], + idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return 
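+            // OffsettedBlockToCTileMapMLoops lets a fixed-size grid cover a
+            // variable number of M tiles: each group owns grid_size_grp
+            // blocks, and the kernel's while-loop above advances id_off by
+            // grid_size_grp until every tile of the group has been emitted.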
block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
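+            // M01 swizzling: the grid is walked in column-major panels of
+            // M01 M-tiles, intended to keep consecutive workgroups on nearby
+            // C tiles.  Worked example (M0 = 5, N0 = 2, M01 = 4): blocks 0..3
+            // cover column n = 0, rows m = 0..3, blocks 4..7 cover column
+            // n = 1; the leftover row m = 4 uses the shrunken
+            // M01_adapt = M0 % M01 = 1 so no block maps out of range.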
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + struct GemmBiasTransKernelArg + { + // pointers + std::array as_ptr_; + std::array bs_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + std::array StrideAs_; + std::array StrideBs_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t) {} + + Argument(std::vector>&, + std::vector>&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + // pointer + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + + static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; }); + static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; }); + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + + const index_t StrideE = gemm_descs[i].stride_C_; + + if(gemm_descs[i].stride_As_.size() != NumATensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); + } + + static_for<0, NumATensor, 1>{}( + [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; }); + + if(gemm_descs[i].stride_Bs_.size() != NumBTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); + } + + static_for<0, NumBTensor, 1>{}( + [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; }); + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! 
gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + static_for<0, NumDTensor, 1>{}( + [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + p_as_grid, + p_bs_grid, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_ = 1; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! 
grouped_gemm_kernel_args_dev is nullptr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + const auto kernel = kernel_grouped_gemm_xdl_fixed_nk< + GridwiseGemm, + GroupedGemmMultiABDKernelArgument, + GemmSpec, + AsLayout, + BsLayout, + DsLayout, + ELayout, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + // k_batch_ > 1 splits K across workgroups, so partial products must be combined with + // AtomicAdd; a single K batch can Set the output directly. + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant<bool, true>{}, + integral_constant<InMemoryDataOperationEnum, AtomicAdd>{}); + } + else + { + ave_time = + launch_kernel(integral_constant<bool, false>{}, + integral_constant<InMemoryDataOperationEnum, AtomicAdd>{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant<bool, true>{}, + integral_constant<InMemoryDataOperationEnum, Set>{}); + } + else + { + ave_time = launch_kernel(integral_constant<bool, false>{}, + integral_constant<InMemoryDataOperationEnum, Set>{}); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert<index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding, we do not support vector loads for dimensions not divisible by the + // vector load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // The [A|B]BlockTransferSrcVectorDim values define the dimension in the block {K0,M,K1} + // layout, thus we have to adapt them to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ?
1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetElementwiseOps(Argument& arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + arg.a_element_op_ = a_element_op; + arg.b_element_op_ = b_element_op; + arg.c_element_op_ = c_element_op; + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) const override + { + + SetElementwiseOps( + *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * + sizeof(GroupedGemmMultiABDKernelArgument); + } + +#if 0 + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * arg.barrier_size_grp_ * 
sizeof(uint32_t); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + } +#endif + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp index 6c5895b010..609c4c2f5b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp @@ -11,7 +11,6 @@ #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index f6e57aad09..636d34ef68 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck/utility/data_type.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { namespace tensor_operation { @@ -92,6 +92,15 @@ struct Add }; }; +struct Scales +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const + { + y = ck::type_convert(ck::type_convert(x0) * ck::type_convert(x1)); + } +}; + struct Max { template @@ -485,6 +494,19 @@ struct AddFastGelu e = type_convert(x1_f); } + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const + { + const float x0_f = type_convert(c) + type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + template <> __host__ __device__ constexpr void operator()(bhalf_t& e, const float& c, const bhalf_t& d) const diff --git a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp index 6d1d6b57c5..d8bac8da7a 100644 --- a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -14,6 +14,8 @@ namespace element_wise { template struct UnaryCombinedOp { + __host__ __device__ UnaryCombinedOp() : unary_ops_() {} + __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) 
{} template @@ -32,6 +34,8 @@ struct UnaryCombinedOp template struct BinaryWithUnaryCombinedOp { + __host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {} + __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, UnaryOp0 unary_op0, UnaryOp1 unary_op1) @@ -63,6 +67,11 @@ template struct TrinaryWithUnaryCombinedOp { + __host__ __device__ TrinaryWithUnaryCombinedOp() + : binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_() + { + } + __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, BinaryOp0 binary_op1, UnaryOp0 unary_op0, diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index eeba8dc8e0..0b8670332d 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -288,10 +288,13 @@ struct ConvertF8RNE struct Scale { - __host__ __device__ Scale(float scale) : scale_(scale) {} + __host__ __device__ Scale(float scale = 1.f) : scale_(scale) {} template - __host__ __device__ void operator()(Y& y, const X& x) const; + __host__ __device__ void operator()(Y& y, const X& x) const + { + y = ck::type_convert(ck::type_convert(x) * scale_); + } template <> __host__ __device__ void operator()(half_t& y, const half_t& x) const @@ -500,6 +503,36 @@ struct FastGelu y = type_convert(y_f); } + + template <> + __device__ void operator()(bhalf_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __host__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } }; // https://paperswithcode.com/method/gelu diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 0f98f9e63d..f4c0a3d911 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -439,7 +439,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle template __host__ __device__ static auto - MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB) { constexpr auto matrix_padder = ck::tensor_operation::device::MatrixPadder{ @@ -463,15 +463,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle template __host__ __device__ static auto - MakeBsGridDescriptor_N_K(const std::array& KRaws, - const std::array& NRaws, + MakeBsGridDescriptor_N_K(const std::array& NRaws, + const std::array& KRaws, const std::array& BsStride) { return generate_tuple( [&](auto i) { using BLayout = remove_cvref_t>; - return MakeBGridDescriptor_N_K(KRaws[i], NRaws[i], BsStride[i]); + return MakeBGridDescriptor_N_K(NRaws[i], KRaws[i], BsStride[i]); }, Number{}); } @@ -574,7 +574,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle { return; } - // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -595,8 +594,10 @@ struct 
GridwiseGemmMultipleABD_xdl_cshuffle generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, Number{}); +#if 0 static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1, "Src and Dst ScalarPerVector must be the same"); +#endif auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, @@ -626,8 +627,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, Number{}); +#if 0 static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1, "Src and Dst ScalarPerVector must be the same"); +#endif auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 3fdf686523..bcce930fc7 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -10,38 +10,9 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { - -// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory -// and sometimes useless instructions: -// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument -// instead -// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same -// tensor coordinate instead -// 3. Don't use a pointer to VGPR buffer, use vector instead - -namespace detail { -// TODO: How to fix this? It uses an struct instead of lambda because lambda -// doesn't have constructor -template -struct lambda_scalar_per_access -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - return (i == VectorDim) ? ScalarPerVector : 1; - } -}; - -template -struct lambda_scalar_step_in_vector -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - return (i == VectorDim) ? 1 : 0; - } -}; -} // namespace detail - // Assume: // 1. src: // 1. SrcDesc is known at compile-time diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp new file mode 100644 index 0000000000..96b95579f5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 
ScalarPerVector : 1; + } +}; + +template +struct lambda_scalar_step_in_vector +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 1 : 0; + } +}; + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 699a34418d..96ea04c8fa 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,44 +7,13 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" #include "ck/utility/is_detected.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { -namespace detail { -// TODO: How to fix this? It uses an struct instead of lambda because lambda -// doesn't have constructor -template -struct lambda_scalar_per_access_for_src_and_dst -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - if(i == SrcVectorDim && i == DstVectorDim) - { - return math::lcm(SrcScalarPerVector, DstScalarPerVector); - } - else if(i == SrcVectorDim) - { - return SrcScalarPerVector; - } - else if(i == DstVectorDim) - { - return DstScalarPerVector; - } - else - { - return 1; - } - } -}; - -} // namespace detail - // Assume: // 1. src_desc and dst_desc are not known at compile-time // 2. SrcBuffer and DstBuffer are DynamicBuffer diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp index 299a4b9e7d..1643c244ee 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,9 +8,11 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" #include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" namespace ck { - // Thread-level multi-source, multi-destination tensor slice data movement // Assume: // 1. 
All sources and destinations are DynamicBuffer @@ -70,16 +72,18 @@ struct ThreadwiseTensorSliceTransfer_v7r2 static constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SrcSpaceFillingCurve = SpaceFillingCurve>; - static constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + using DstSpaceFillingCurve = SpaceFillingCurve>; + remove_cv_t, + false>; __device__ constexpr ThreadwiseTensorSliceTransfer_v7r2( const SrcDescs& src_descs, @@ -139,9 +143,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2 __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) { // loop over space-filling curve - static_for<0, num_access, 1>{}([&](auto iAccess) { + static_for<0, src_num_access, 1>{}([&](auto iAccess) { auto src_vectors = generate_vectors(); - auto dst_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); // copy data from src_bufs into src_vectors static_for<0, nSrc, 1>{}([&](auto i) { @@ -199,7 +203,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 using elem_op_vec_t = typename vector_type::type; - return dst_vectors(iDst).template AsType()(i); + return elm_vectors(iDst).template AsType()(i); }, Number{}); @@ -214,10 +218,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2 unpack2(element_op_, dst_data_refs, src_data_refs); }); - dst_vectors_tuple_(iAccess) = dst_vectors; + elm_vectors_tuple_(iAccess) = elm_vectors; // move coordinate - if constexpr(iAccess.value != num_access - 1) + if constexpr(iAccess.value != src_num_access - 1) { constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); @@ -241,15 +245,113 @@ struct ThreadwiseTensorSliceTransfer_v7r2 }); } + __device__ void TransposeFromElmToDst() + { + using DstData = remove_cvref_t; + + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer; + + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + SrcThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src 
vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_ = bit_cast(dst_thread_scratch_.data_); + } + // DstDescs: Tuple // DstBuffers: Tuple template = false> + enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) { + TransposeFromElmToDst(); + // loop over space-filling curve - static_for<0, num_access, 1>{}([&](auto iAccess) { - auto dst_vectors = dst_vectors_tuple_[iAccess]; + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[Number{}]; // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { @@ -269,7 +371,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2 }); // move coordinate - if constexpr(iAccess.value != num_access - 1) + if constexpr(iAccess.value != dst_num_access - 1) { constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); @@ -312,28 +414,126 @@ struct ThreadwiseTensorSliceTransfer_v7r2 __device__ static constexpr auto GetSrcCoordinateResetStep() { - if constexpr(num_access == 0) + if constexpr(src_num_access == 0) { return typename SrcSpaceFillingCurve::Index{}; } else { - return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); } } __device__ static constexpr auto GetDstCoordinateResetStep() { - if constexpr(num_access == 0) + if constexpr(dst_num_access == 0) { return typename DstSpaceFillingCurve::Index{}; } else { - return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); } } + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto 
up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + // src_slice_origin_step_idx need to be known at compile-time, for performance reason template __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, @@ -372,11 +572,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2 private: using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); using DstVectorsType = decltype(generate_vectors()); - static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); - StaticallyIndexedArray dst_vectors_tuple_; + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; SrcCoords src_coords_; DstCoords dst_coords_; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp new file mode 100644 index 0000000000..c6a40e3b64 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp @@ -0,0 +1,468 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
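// Usage sketch (illustrative, not part of this patch): the factories declared below
// follow the usual CK client-example flow -- enumerate instances, filter with
// IsSupportedArgument, then run one. The MakeArgumentPointer parameter list is
// abbreviated here, and the template-parameter order shown is assumed to mirror
// DeviceGemmMultipleABD.
//
//   using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<
//       AsLayout, BsLayout, DsLayout, ELayout,
//       AsDataType, BsDataType, DsDataType, EDataType,
//       AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>;
//
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//
//   for(const auto& op_ptr : op_ptrs)
//   {
//       auto arg_ptr = op_ptr->MakeArgumentPointer(/* as/bs/ds/e pointers, M/N/K,
//                                                     strides, element-wise ops */);
//       if(op_ptr->IsSupportedArgument(arg_ptr.get()))
//       {
//           op_ptr->MakeInvokerPointer()->Run(arg_ptr.get(), StreamConfig{});
//           break; // first supported instance; a tuner would time them all
//       }
//   }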
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scales = ck::tensor_operation::element_wise::Scales; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +#ifdef CK_ENABLE_INT8 +// RRR +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + AddFastGelu>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + Add>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + FastGelu>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + PassThrough>>>& instances); + +// RCR +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + AddFastGelu>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + Add>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + FastGelu>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + PassThrough>>>& instances); + +// CRR +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + AddFastGelu>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + Add>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + FastGelu>>>& instances); + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + PassThrough>>>& instances); +#endif + +// GEMM + Add + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABD> +{ + using 
DeviceOp = DeviceGemmMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +// GEMM + Add +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABD> +{ + using DeviceOp = DeviceGemmMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +// GEMM + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABD> +{ + using DeviceOp = DeviceGemmMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +// GEMM +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABD> +{ + using DeviceOp = DeviceGemmMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp new file mode 100644 index 0000000000..482b7d0b5e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp @@ -0,0 +1,470 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scales = ck::tensor_operation::element_wise::Scales; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +// RRR +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + Add>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + FastGelu>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + PassThrough>>>& instances); + +// RCR +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + Add>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + FastGelu>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + PassThrough>>>& instances); + +// CRR +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + AddFastGelu>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Scales, + Add>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + 
std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + FastGelu>>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Scales, + PassThrough>>>& instances); + +// GEMM + Add + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK> +{ + using DeviceOp = DeviceGroupedGemmMultiABDFixedNK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +// GEMM + Add +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK> +{ + using DeviceOp = DeviceGroupedGemmMultiABDFixedNK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +// GEMM + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK> +{ + using DeviceOp = DeviceGroupedGemmMultiABDFixedNK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +// GEMM +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK> +{ + using DeviceOp = DeviceGroupedGemmMultiABDFixedNK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + 
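// Layout triple matched mk_kn_mn (row-major A0/B0, row-major E): append the RRR
// instance list. The km_kn_mn and mk_nk_mn branches below cover the CRR and RCR
// layouts declared at the top of this header.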
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..9cef62a22e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt @@ -0,0 +1,10 @@ +# ONLY XDL_KERNELS +set(GEMM_MULTI_ABD_INSTANCES) + +list(APPEND GEMM_MULTI_ABD_INSTANCES + device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp + ) + +add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp new file mode 100644 index 0000000000..d2a7654077 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
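// Note (assumption drawn from this patch, not original to the file): the instance
// lists below pair an int8 B0 (quantized weights) with a bf16 B1 (scales) and apply
// BElementOp = Scales while B is loaded, so the XDL pipeline consumes dequantized
// bf16 values. A scalar sketch of that step, using the Scales operator added in
// binary_element_wise_operation.hpp:
//
//   int8_t      w = /* element of B0 */;
//   ck::bhalf_t s = /* element of B1 */;
//   ck::bhalf_t b; // dequantized value the GEMM multiplies against A
//   ck::tensor_operation::element_wise::Scales{}(b, w, s);
//   // effectively b = type_convert<bhalf_t>(type_convert<float>(w) *
//   //                                       type_convert<float>(s))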
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +// using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Col; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +// using DsLayout = ck::Tuple; +using ELayout = Row; + +using Scales = ck::tensor_operation::element_wise::Scales; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using Add = ck::tensor_operation::element_wise::Add; + +using AElementOp = PassThrough; +using BElementOp = Scales; +// using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple< + // clang-format off + //###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1 + 
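// How to read one row (columns per the header above), using the first instance as
// an example: 256 threads per block compute a 256x128 output tile over
// K0PerBlock = 32 with AK1 = BK1 = 8, on 32x32 XDL ops arranged 4x2 per wave,
// followed by the A/B block-transfer cluster shapes, source vector dims/widths,
// and the CShuffle epilogue tiling; LoopSche and PipVer select the loop scheduler
// and pipeline version for every row in the list.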
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>, + 
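// The remaining rows scale the same recipe down (to tiles as small as 16x32 per
// block on 16x16 XDL ops) so small or skinny GEMMs can still keep the device
// occupied; the per-vector scalar counts shrink with the tile so vectorized
// accesses stay within bounds.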
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer> + // clang-format on + >; +} // 
namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp
new file mode 100644
index 0000000000..3b8df6b18d
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp
@@ -0,0 +1,101 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
+
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8 = int8_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType = BF16;
+using AsDataType = ck::Tuple<A0DataType>;
+using B0DataType = I8;
+using B1DataType = BF16;
+using BsDataType = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType = F32;
+using CShuffleDataType = BF16;
+using D0DataType = BF16;
+// using DsDataType = ck::Tuple<D0DataType>;
+using EDataType = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+// using DsLayout = ck::Tuple<D0Layout>;
+using ELayout = Row;
+
+using Scales = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+using FastGelu = ck::tensor_operation::element_wise::FastGelu;
+using Add = ck::tensor_operation::element_wise::Add;
+
+using AElementOp = PassThrough;
+using BElementOp = Scales;
+// using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <typename DsLayout,
+          typename DsDataType,
+          typename CDEElementOp,
+          GemmSpecialization GemmSpec,
+          PipelineVersion PipVer,
+          LoopScheduler LoopSche>
+using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
+    // clang-format off
+    //###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1 + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, 
AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, 
CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
+    DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer>
+    // clang-format on
+    >;
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp
new file mode 100644
index 0000000000..d4d85ef893
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp
@@ -0,0 +1,101 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
+
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8 = int8_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType = BF16;
+using AsDataType = ck::Tuple<A0DataType>;
+using B0DataType = I8;
+using B1DataType = BF16;
+using BsDataType = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType = F32;
+using CShuffleDataType = BF16;
+using D0DataType = BF16;
+// using DsDataType = ck::Tuple<D0DataType>;
+using EDataType = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Col;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+// using DsLayout = ck::Tuple<D0Layout>;
+using ELayout = Row;
+
+using Scales = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+using FastGelu = ck::tensor_operation::element_wise::FastGelu;
+using Add = ck::tensor_operation::element_wise::Add;
+
+using AElementOp = PassThrough;
+using BElementOp = Scales;
+// using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+template <typename DsLayout,
+          typename DsDataType,
+          typename CDEElementOp,
+          GemmSpecialization GemmSpec,
+          PipelineVersion PipVer,
+          LoopScheduler LoopSche>
+using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple<
+    // clang-format off
+    //###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| 
NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1 + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 48, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 24, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, 
S<1, 16, 1, 8>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>, + DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..fe377a9383 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
+
+#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                      AElementOp, BElementOp, AddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<D0Layout>,
+                                                                  ck::Tuple<D0DataType>,
+                                                                  AddFastGelu,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                      AElementOp, BElementOp, Add>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<D0Layout>,
+                                                                  ck::Tuple<D0DataType>,
+                                                                  Add,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                      AElementOp, BElementOp, PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<>,
+                                                                  ck::Tuple<>,
+                                                                  PassThrough,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                      AElementOp, BElementOp, FastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<>,
+                                                                  ck::Tuple<>,
+                                                                  FastGelu,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
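For context: the four functions above only register kernel instances; the caller owns the vector and can inspect what was registered. A minimal sketch of direct use (type aliases as in the km_kn common header of this patch; main() and data setup omitted, so treat this as illustration rather than shipped API surface):

    #include <iostream>
    #include <memory>
    #include <vector>

    // Abstract interface the km_kn bias+gelu instances are registered against.
    using DeviceOpPtr = std::unique_ptr<ck::tensor_operation::device::DeviceGemmMultipleABD<
        AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
        AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
        AElementOp, BElementOp, AddFastGelu>>;

    void list_bias_gelu_instances()
    {
        std::vector<DeviceOpPtr> ops;
        ck::tensor_operation::device::instance::
            add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(ops);
        for(const auto& op : ops)
            std::cout << op->GetTypeString() << '\n'; // one entry per tile configuration
    }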
diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
new file mode 100644
index 0000000000..d97528b4a5
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
@@ -0,0 +1,115 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
+
+#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                      AElementOp, BElementOp, AddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<D0Layout>,
+                                                                  ck::Tuple<D0DataType>,
+                                                                  AddFastGelu,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                      AElementOp, BElementOp, Add>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<D0Layout>,
+                                                                  ck::Tuple<D0DataType>,
+                                                                  Add,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                      AElementOp, BElementOp, PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<>,
+                                                                  ck::Tuple<>,
+                                                                  PassThrough,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                      AElementOp, BElementOp, FastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<>,
+                                                                  ck::Tuple<>,
+                                                                  FastGelu,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
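In practice these per-layout registration functions are rarely called directly; the factory header added by this PR (include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp) aggregates them. A sketch of the intended lookup, following CK's usual DeviceOperationInstanceFactory pattern (a sketch, assuming the same aliases as above):

    #include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"

    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<
        AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
        AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
        AElementOp, BElementOp, AddFastGelu>;

    // Collects every registered instance whose template signature matches DeviceOp.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();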
diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
new file mode 100644
index 0000000000..bc64513c08
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
@@ -0,0 +1,115 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
+
+#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                      AElementOp, BElementOp, AddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<D0Layout>,
+                                                                  ck::Tuple<D0DataType>,
+                                                                  AddFastGelu,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                      AElementOp, BElementOp, Add>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<D0Layout>,
+                                                                  ck::Tuple<D0DataType>,
+                                                                  Add,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                      AElementOp, BElementOp, PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<>,
+                                                                  ck::Tuple<>,
+                                                                  PassThrough,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                      AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                      AElementOp, BElementOp, FastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<>,
+                                                                  ck::Tuple<>,
+                                                                  FastGelu,
+                                                                  GemmMNKPadding,
+                                                                  PipelineVersion::v1,
+                                                                  LoopScheduler::Default>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
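Once an instance has been chosen, dispatch follows the usual CK device-op flow. The argument layout below mirrors the example/60_gemm_multi_ABD code added by this PR, but the exact MakeArgumentPointer signature is an assumption here and should be checked against the shipped header; the p_*, M/N/K and Stride* names are placeholders:

    // p_a0: bf16 activations, p_b0: int8 weights, p_b1: bf16 per-channel scales,
    // p_d0: bf16 bias row, p_e: bf16 output. All pointers are device memory.
    auto& op = op_ptrs[0];
    auto arg = op->MakeArgumentPointer(
        std::array<const void*, 1>{p_a0},
        std::array<const void*, 2>{p_b0, p_b1},
        std::array<const void*, 1>{p_d0},
        p_e,
        M, N, K,
        std::array<ck::index_t, 1>{StrideA},
        std::array<ck::index_t, 2>{StrideB, 0}, // scale has no K stride
        std::array<ck::index_t, 1>{0},          // bias row is broadcast over M
        StrideE,
        AElementOp{}, BElementOp{}, AddFastGelu{});
    auto invoker = op->MakeInvokerPointer();
    if(op->IsSupportedArgument(arg.get()))
        invoker->Run(arg.get(), StreamConfig{nullptr, false});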
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt
new file mode 100644
index 0000000000..e38c82d396
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt
@@ -0,0 +1,10 @@
+# ONLY XDL_KERNELS
+set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES)
+
+list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES
+    device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
+    device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
+    device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
+    )
+
+add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES})
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp
new file mode 100644
index 0000000000..c9d61513de
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp
@@ -0,0 +1,89 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8 = int8_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType = BF16;
+using AsDataType = ck::Tuple<A0DataType>;
+using B0DataType = I8;
+using B1DataType = BF16;
+using BsDataType = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType = F32;
+using CShuffleDataType = BF16;
+using D0DataType = BF16;
+// using DsDataType = ck::Tuple<D0DataType>;
+using EDataType = BF16;
+
+using A0Layout = Col;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+// using DsLayout = ck::Tuple<D0Layout>;
+using ELayout = Row;
+
+using Scales = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+using Add = ck::tensor_operation::element_wise::Add;
+using FastGelu = ck::tensor_operation::element_wise::FastGelu;
+
+using AElementOp = PassThrough;
+using BElementOp = Scales;
+// using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+template <typename DsLayout,
+          typename DsDataType,
+          typename CDEElementOp,
+          GemmSpecialization GemmSpec>
+using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple<
+    // clang-format off
+    //######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
+        DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
+    // clang-format on
+    >;
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp
new file mode 100644
index 0000000000..8842391fec
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp
@@ -0,0 +1,89 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8 = int8_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType = BF16;
+using AsDataType = ck::Tuple<A0DataType>;
+using B0DataType = I8;
+using B1DataType = BF16;
+using BsDataType = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType = F32;
+using CShuffleDataType = BF16;
+using D0DataType = BF16;
+// using DsDataType = ck::Tuple<D0DataType>;
+using EDataType = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Row;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+// using DsLayout = ck::Tuple<D0Layout>;
+using ELayout = Row;
+
+using Scales = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+using Add = ck::tensor_operation::element_wise::Add;
+using FastGelu = ck::tensor_operation::element_wise::FastGelu;
+
+using AElementOp = PassThrough;
+using BElementOp = Scales;
+// using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+template <typename DsLayout,
+          typename DsDataType,
+          typename CDEElementOp,
+          GemmSpecialization GemmSpec>
+using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
+    // clang-format off
+    //######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, 
CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp new file mode 100644 index 0000000000..75d9fd1d39 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using BF16 = ck::bhalf_t;
+using I8 = int8_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using A0DataType = BF16;
+using AsDataType = ck::Tuple<A0DataType>;
+using B0DataType = I8;
+using B1DataType = BF16;
+using BsDataType = ck::Tuple<B0DataType, B1DataType>;
+using AccDataType = F32;
+using CShuffleDataType = BF16;
+using D0DataType = BF16;
+// using DsDataType = ck::Tuple<D0DataType>;
+using EDataType = BF16;
+
+using A0Layout = Row;
+using AsLayout = ck::Tuple<A0Layout>;
+using B0Layout = Col;
+using B1Layout = B0Layout;
+using BsLayout = ck::Tuple<B0Layout, B1Layout>;
+using D0Layout = Row;
+// using DsLayout = ck::Tuple<D0Layout>;
+using ELayout = Row;
+
+using Scales = ck::tensor_operation::element_wise::Scales;
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
+using Add = ck::tensor_operation::element_wise::Add;
+using FastGelu = ck::tensor_operation::element_wise::FastGelu;
+
+using AElementOp = PassThrough;
+using BElementOp = Scales;
+// using CDEElementOp = AddFastGelu;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+template <typename DsLayout,
+          typename DsDataType,
+          typename CDEElementOp,
+          GemmSpecialization GemmSpec>
+using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple<
+    // clang-format off
+    //######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+    //######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+    //######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, 
CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
+    // clang-format on
+    >;
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
new file mode 100644
index 0000000000..0d9af198cf
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
@@ -0,0 +1,111 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                                 AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                                 AElementOp, BElementOp, AddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
+            ck::Tuple<D0Layout>,
+            ck::Tuple<D0DataType>,
+            AddFastGelu,
+            GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
+                                                                 AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
+                                                                 AElementOp, BElementOp, Add>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
+            ck::Tuple<D0Layout>,
+            ck::Tuple<D0DataType>,
+            Add,
+            GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                                 AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                                 AElementOp, BElementOp, PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
+            ck::Tuple<>,
+            ck::Tuple<>,
+            PassThrough,
+            GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout, BsLayout, ck::Tuple<>, ELayout,
+                                                                 AsDataType, BsDataType, ck::Tuple<>, EDataType,
+                                                                 AElementOp, BElementOp, FastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
+            ck::Tuple<>,
+            ck::Tuple<>,
+            FastGelu,
+            GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
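The grouped fixed-NK instances registered above are looked up the same way, through the new grouped_gemm_multi_abd_fixed_nk.hpp factory header. A sketch, with the interface name taken from the device headers added by this PR (verify against the shipped header before use):

    #include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"

    using GroupedOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<
        AsLayout, BsLayout, ck::Tuple<D0Layout>, ELayout,
        AsDataType, BsDataType, ck::Tuple<D0DataType>, EDataType,
        AElementOp, BElementOp, AddFastGelu>;

    const auto grouped_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<GroupedOp>::GetInstances();
    // "Fixed NK" means every group shares the same N and K and only M varies per
    // group, which lets the device op size its workspace once for all groups.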
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
new file mode 100644
index 0000000000..0f81855489
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
@@ -0,0 +1,111 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<Row>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<BF16>, EDataType,
+        AElementOp, BElementOp, AddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
+            ck::Tuple<Row>, ck::Tuple<BF16>, AddFastGelu, GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<Row>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<BF16>, EDataType,
+        AElementOp, BElementOp, Add>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
+            ck::Tuple<Row>, ck::Tuple<BF16>, Add, GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<>, EDataType,
+        AElementOp, BElementOp, PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
+            ck::Tuple<>, ck::Tuple<>, PassThrough, GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<>, EDataType,
+        AElementOp, BElementOp, FastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
+            ck::Tuple<>, ck::Tuple<>, FastGelu, GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
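
Note: the four variants in each of these files differ only in the CDE elementwise operator and the D-tensor tuples: bias variants carry one row-major BF16 D tensor (the bias), the others carry none. As a reference for what the AddFastGelu epilogue computes per output element, here is a scalar sketch, not part of the patch; the function names are illustrative and the constants are those of the standard tanh-based fast-GELU approximation, which CK's device functor follows up to floating-point detail.

    #include <cmath>

    // fast_gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    // written in the equivalent sigmoid form: x * sigmoid(u) with u = 2*v.
    inline float fast_gelu(float x)
    {
        const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
        return x / (1.f + std::exp(-u));
    }

    // AddFastGelu: bias add (the single D tensor) followed by fast GELU.
    inline float add_fast_gelu(float acc, float bias) { return fast_gelu(acc + bias); }
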
diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
new file mode 100644
index 0000000000..67e831e6d6
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
@@ -0,0 +1,111 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<Row>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<BF16>, EDataType,
+        AElementOp, BElementOp, AddFastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
+            ck::Tuple<Row>, ck::Tuple<BF16>, AddFastGelu, GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<Row>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<BF16>, EDataType,
+        AElementOp, BElementOp, Add>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
+            ck::Tuple<Row>, ck::Tuple<BF16>, Add, GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<>, EDataType,
+        AElementOp, BElementOp, PassThrough>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
+            ck::Tuple<>, ck::Tuple<>, PassThrough, GemmMNKPadding>{});
+}
+
+void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<
+        AsLayout, BsLayout, ck::Tuple<>, ELayout,
+        AsDataType, BsDataType, ck::Tuple<>, EDataType,
+        AElementOp, BElementOp, FastGelu>>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
+            ck::Tuple<>, ck::Tuple<>, FastGelu, GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
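
Note: the registration functions can also be called directly, bypassing the factory. A minimal sketch for the no-D, PassThrough variant declared in the file above, assuming the alias names from device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp are in scope:

    // Vector element type must match the function's parameter exactly.
    std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<
        AsLayout, BsLayout, ck::Tuple<>, ELayout,
        AsDataType, BsDataType, ck::Tuple<>, EDataType,
        AElementOp, BElementOp, PassThrough>>> instances;

    ck::tensor_operation::device::instance::
        add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(instances);

    // Each entry is one tuned configuration from the mk_nk common header and
    // exposes the device-op interface (argument construction, support checks,
    // workspace query) through the DeviceGroupedGemmMultiABDFixedNK base class.
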