diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 6b75ff76ce..9c3b1a7071 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -50,6 +50,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp) add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16) +add_example_executable(example_grouped_gemm_wmma_fixed_nk_bias_fp16 grouped_gemm_wmma_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_bias_fp16) + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000..cc475fdccc --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_bias_fp16.cpp @@ -0,0 +1,406 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +using ck::DeviceMem; +using ck::hip_check_error; +using ck::HostTensorDescriptor; +using ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using SplitKAdd = ck::tensor_operation::element_wise::SplitKAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +using CDEElementOp = SplitKAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>; + +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Ds; + std::vector stride_Es; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + int group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_a, p_b; + std::vector> p_ds = {}; + std::vector p_e; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> e_host_tensors; + std::vector> e_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + e_host_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, d0_tensors_device, + e_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + e_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; ++i) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + d0_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Ds[i], D0Layout{}))); + e_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Es[i], ELayout{}))); + 
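An aside on the stride convention used by `f_host_tensor_descriptor` above: a row-major matrix gets strides `{stride, 1}` (elements within a row are adjacent in memory), a column-major one `{1, stride}`. A minimal standalone sketch of the resulting address math — illustration only, not part of the diff:

```cpp
#include <cstddef>

// Row-major M x K with leading dimension lda: element (m, k) sits at m * lda + k.
std::size_t offset_row_major(std::size_t m, std::size_t k, std::size_t lda) { return m * lda + k; }

// Column-major M x K with leading dimension lda: element (m, k) sits at m + k * lda.
std::size_t offset_col_major(std::size_t m, std::size_t k, std::size_t lda) { return m + k * lda; }
```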
e_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Es[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " e_m_n: " << e_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back(std::make_unique( + sizeof(D0DataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + e_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data(), + d0_tensors[i].mDesc.GetElementSpaceSize() * + sizeof(D0DataType)); + e_tensors_device[i]->SetZero(); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_ds.push_back(std::array{d0_tensors_device[i]->GetDeviceBuffer()}); + p_e.push_back(e_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Es[i], + {problem_size.stride_Ds[i]}}); + + grouped_gemm_kernel_args_.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + e_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + std::array{problem_size.stride_Ds[i]}, + problem_size.stride_Es[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // 
do GEMM + auto argument = gemm.MakeArgument( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(), + e_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + e_host_tensors[i], + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl;
+    }
+    return pass;
+}
+
+int main(int argc, char* argv[])
+{
+    ProblemSize problem_size;
+    ExecutionConfig config;
+
+    problem_size.group_count = 16;
+
+    if(argc == 1)
+    {
+        // use default cases
+    }
+    else if(argc == 5)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method     = std::stoi(argv[2]);
+        config.time_kernel     = std::stoi(argv[3]);
+        config.k_batch         = std::stoi(argv[4]);
+    }
+    else if(argc == 6)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method     = std::stoi(argv[2]);
+        config.time_kernel     = std::stoi(argv[3]);
+        config.k_batch         = std::stoi(argv[4]);
+        problem_size.group_count = std::stoi(argv[5]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4: k_batch (>0)\n");
+        printf("arg5: group count (default=16)\n");
+        exit(0);
+    }
+
+    // Lambda to get stride based on layout
+    auto get_stride = [](auto layout, auto row_dim, auto col_dim) {
+        if constexpr(std::is_same_v<decltype(layout), Row>)
+        {
+            return col_dim;
+        }
+        else
+        {
+            return row_dim;
+        }
+    };
+
+    for(int i = 0; i < problem_size.group_count; ++i)
+    {
+        problem_size.Ms.push_back(256 + 256 * i);
+
+        problem_size.Ns.push_back(512);
+        problem_size.Ks.push_back(512);
+
+        problem_size.stride_As.push_back(
+            get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i]));
+        problem_size.stride_Bs.push_back(
+            get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i]));
+        problem_size.stride_Ds.push_back(
+            get_stride(D0Layout{}, problem_size.Ms[i], problem_size.Ns[i]));
+        problem_size.stride_Es.push_back(
+            get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i]));
+    }
+
+    return !run_grouped_gemm(problem_size, config);
+}
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
index 83ab9774ce..8a9afc1733 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
@@ -7,22 +7,55 @@
 #include
 #include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
-#include "ck/utility/env.hpp"
 #include "ck/host_utility/hip_check_error.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
 #include "ck/utility/common_header.hpp"
+#include "ck/utility/env.hpp"
 #include "ck/utility/scheduler_enum.hpp"
 #include "ck/utility/tuple.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
-#include "ck/host_utility/device_prop.hpp"
-#include "ck/host_utility/kernel_launch.hpp"
 
 namespace ck {
 namespace tensor_operation {
+namespace element_wise {
+struct SplitKAdd
+{
+    static constexpr const char* name = "SplitKAdd";
+
+    __host__ __device__ void set_kbatch(const index_t& id, const index_t& total)
+    {
+        kbatch_id = id;
+        KBatch    = total;
+    }
+
+    template <typename Y, typename X0, typename X1>
+    __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
+    {
+        if(kbatch_id == KBatch - 1)
+        {
+            add_op(y,
x0, x1); + } + else + { + passthrough_op(y, x0); + } + } + + private: + index_t kbatch_id = 0; + index_t KBatch = 1; + static constexpr auto add_op = Add{}; + static constexpr auto passthrough_op = PassThrough{}; +}; +} // namespace element_wise + namespace device { template {}]); + + auto c_element_op_copy(c_element_op); + if constexpr(std::is_same_v) + { + c_element_op_copy.set_kbatch(kbatch_id, k_batch_); + } + KernelArgument kernel_arg{std::array{gemmTransKernelArg.p_a_grid}, std::array{gemmTransKernelArg.p_b_grid}, gemmTransKernelArg.p_ds_grid, @@ -124,15 +172,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) k_batch_, a_element_op, b_element_op, - c_element_op, + c_element_op_copy, false}; - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off); - - const auto tile_index = - block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - const auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]); @@ -813,7 +855,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK) + ck::tensor_operation::element_wise::PassThrough> && + !std::is_same_v) { if(arg.k_batch_ > 1) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 0bbf725521..a843a94777 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -8,27 +8,28 @@ #include #endif -#include "ck/utility/env.hpp" -#include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" -#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp" +#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp" +#include 
"ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp index 2a34e3068b..1221dc9099 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp @@ -3,13 +3,14 @@ #pragma once -#include -#include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" - #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/utility/type.hpp" + +#include +#include namespace ck { namespace tensor_operation { @@ -70,20 +71,52 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances( PassThrough, Add>>>& instances); +void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances( + std::vector< + std::unique_ptr>>& + instances); + +void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances( + std::vector< + std::unique_ptr>>& + instances); + template struct DeviceOperationInstanceFactory< ck::tensor_operation::device::DeviceGroupedGemmFixedNK && is_same_v && is_same_v) { +#if defined(CK_USE_XDL) add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs); +#endif } - if constexpr(is_same_v && is_same_v && - is_same_v) + else if constexpr(is_same_v && is_same_v && + is_same_v) { +#if defined(CK_USE_XDL) add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs); +#endif } } // fp32_output - if constexpr(is_same_v && is_same_v && - is_same_v) + else if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs); +#endif } - if constexpr(is_same_v && is_same_v && + else if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_USE_XDL + add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs); +#endif + } + } + return op_ptrs; + } +}; + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGroupedGemmFixedNK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + // fp16_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs); +#if defined(CK_USE_WMMA) + add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs); +#endif + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { +#if defined(CK_USE_WMMA) + 
add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs); +#endif } } return op_ptrs; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt index 950784eb46..84bb1fc2ea 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt @@ -1,11 +1,14 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_grouped_gemm_bias_instance device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp + + device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..4972d3e0e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,95 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" + +#include +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using SplitKAdd = ck::tensor_operation::element_wise::SplitKAdd; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = SplitKAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| 
NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..b671fee4dc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,98 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
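A note on the `SplitKAdd` element-wise operation these instance files rely on: with split-K, each k-batch slice contributes a partial product to E, so a naive `Add` epilogue would apply the bias once per slice. `SplitKAdd` guards the addition so the bias lands exactly once, on the final k-batch. A standalone sketch of that gating, with operand types simplified to `float` (the `SplitKAddSketch` name and the `main` driver are illustrative only):

```cpp
#include <cassert>
#include <iostream>

// Standalone sketch of the SplitKAdd gating shown in the device header above.
struct SplitKAddSketch
{
    int kbatch_id = 0; // which k-batch slice this invocation belongs to
    int KBatch    = 1; // total number of k-batch slices

    void operator()(float& y, const float& partial, const float& bias) const
    {
        // Only the last k-batch adds the bias; earlier slices pass the partial
        // accumulation through unchanged, so the bias is applied exactly once.
        y = (kbatch_id == KBatch - 1) ? partial + bias : partial;
    }
};

int main()
{
    SplitKAddSketch op{/*kbatch_id=*/0, /*KBatch=*/2};
    float y = 0.f;
    op(y, 1.5f, 100.f); // first slice: bias skipped
    assert(y == 1.5f);
    op.kbatch_id = 1;
    op(y, 1.5f, 100.f); // last slice: bias applied
    assert(y == 101.5f);
    std::cout << "bias applied once: " << y << '\n';
}
```

Note that a default-constructed op has `kbatch_id == 0` and `KBatch == 1`, so it degenerates to a plain add — which is what the host-side verification loops in this patch rely on.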
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" + +#include +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using SplitKAdd = ck::tensor_operation::element_wise::SplitKAdd; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = SplitKAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, 
S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedGemm_Wmma_Fixed_Nk, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_bias_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_bias_impl.hpp new file mode 100644 index 0000000000..488a0fa660 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_bias_impl.hpp @@ -0,0 +1,411 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
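With both instance lists registered, the profiler below exercises them through the instance factory. A hedged sketch of that lookup in isolation, assuming `DeviceGroupedGemmFixedNK`'s usual parameter order (layouts first, then data types, then the element-wise ops) — the factory header added in this patch pulls in everything referenced here:

```cpp
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"

#include <iostream>

// Sketch: enumerate the registered fixed-NK bias instances for the
// row-row fp16 case and print their names.
using Row         = ck::tensor_layout::gemm::RowMajor;
using F16         = ck::half_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using SplitKAdd   = ck::tensor_operation::element_wise::SplitKAdd;

int main()
{
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<
        Row, Row, ck::Tuple<Row>, Row,
        F16, F16, ck::Tuple<F16>, F16,
        PassThrough, PassThrough, SplitKAdd>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
}
```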
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/env.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/type.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_fixed_nk_bias_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideEs, + const std::vector& kbatches = {1}, + int n_warmup = 1, + int n_iter = 10) +{ + bool pass = true; + using ComputeDataType = ADataType; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideDs.size() && + group_count == StrideEs.size())) + { + throw std::runtime_error("wrong! 
inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + using D0DataType = remove_cvref_t{}, DsDataType>>; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> e_host_tensors; + std::vector> e_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + e_host_tensors.reserve(group_count); + e_device_tensors.reserve(group_count); + + double max_abs_in_val = 0.f; + int sum_of_m = 0; + + using D0Layout = remove_cvref_t{}, DsLayout>>; + + for(std::size_t i = 0; i < group_count; ++i) + { + sum_of_m += Ms[i]; + a_tensors.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_tensors.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + d0_tensors.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideDs[i], D0Layout{}))); + e_host_tensors.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{}))); + e_device_tensors.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{}))); + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_tensors[i].mDesc + << ", b_k_n[" << i << "]:" << b_tensors[i].mDesc << ", d_m_n[" << i + << "]:" << d0_tensors[i].mDesc << ", e_m_n_device_results[" << i + << "]:" << e_device_tensors[i].mDesc << std::endl; + } + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_tensors[i]); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_tensors[i]); + max_abs_in_val = 10.f; + break; + default: + ck::utils::FillUniformDistribution{0.0f, 1.0f}(a_tensors[i]); + ck::utils::FillUniformDistribution{-0.5f, 0.5f}(b_tensors[i]); + max_abs_in_val = 1.0f; + } + ck::utils::FillUniformDistribution{-0.5, 0.5}(d0_tensors[i]); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CDEElementOp = ck::tensor_operation::element_wise::SplitKAdd; + + constexpr auto a_element_op = AElementOp{}; + constexpr auto b_element_op = BElementOp{}; + constexpr auto cde_element_op = CDEElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_tensors_device, b_tensors_device, d0_tensors_device, + e_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + e_tensors_device.reserve(group_count); + + std::vector p_a, p_b; + std::vector> p_ds; + std::vector p_e; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_ds.reserve(group_count); + p_e.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(std::size_t i = 0; i < group_count; ++i) + { + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize())); + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize())); + d0_tensors_device.emplace_back(std::make_unique( + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSpaceSize())); + e_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize())); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + 
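Worth calling out before the descriptor setup a few lines below: in the fixed-NK grouped GEMM API, every group's `GemmDesc` carries `sum_of_m` (the running total of all Ms) in its M slot, while the true per-group M travels in `GroupedGemmKernelArgument` — which is what lets callers supply the actual Ms at launch time via `SetDeviceKernelArgs` without rebuilding the argument. A simplified standalone sketch (struct names hypothetical, not the CK types):

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

struct DescSketch { int M, N, K; }; // M slot holds sum_of_m here
struct ArgSketch  { int M, N, K; }; // real per-group M

int main()
{
    std::vector<int> Ms{64, 128}, Ns{768, 768}, Ks{320, 320};
    const int sum_of_m = std::accumulate(Ms.begin(), Ms.end(), 0);

    std::vector<DescSketch> descs;
    std::vector<ArgSketch>  args;
    for(std::size_t i = 0; i < Ms.size(); ++i)
    {
        descs.push_back({sum_of_m, Ns[i], Ks[i]}); // shared sum over all groups
        args.push_back({Ms[i], Ns[i], Ks[i]});     // per-group M
    }
    // The descriptor's M slot equals the sum of the per-group Ms.
    return descs.front().M == args[0].M + args[1].M ? 0 : 1;
}
```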
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + + gemm_descs.push_back( + {sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {StrideDs[i]}}); + + p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_ds.push_back(std::array{d0_tensors_device[i]->GetDeviceBuffer()}); + p_e.push_back(e_tensors_device[i]->GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + e_tensors_device[i]->GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + std::array{StrideDs[i]}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + std::cerr << "Skip! no device GEMM instance found" << std::endl; + return true; + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm< + ADataType, + BDataType, + EDataType, + AccDataType, + AElementOp, + BElementOp, + ck::tensor_operation::element_wise::PassThrough>; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + e_host_tensors[i], + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < Ms[i]; ++m) + { + for(int n = 0; n < Ns[i]; ++n) + { + cde_element_op( + e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + + DeviceMem grouped_gemm_kernel_args_dev( + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + for(std::size_t j = 0; j < kbatches.size(); ++j) + { + + auto kbatch_curr = kbatches[j]; + + gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + e_tensors_device[i]->SetZero(); + } + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + bool instance_pass = true; + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data()); + auto atol = ck::utils::get_absolute_threshold( + max_abs_in_val, 
gemm_descs[i].K_); + auto rtol = ck::utils::get_relative_threshold( + gemm_descs[i].K_); + + instance_pass = + instance_pass && ck::utils::check_err(e_device_tensors[i], + e_host_tensors[i], + "Error: Incorrect results!", + rtol, + atol); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_tensors[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_tensors[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0: ", d0_tensors[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "e_device: ", e_device_tensors[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "e_host : ", e_host_tensors[i].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += + sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(D0DataType) * Ms[i] * Ns[i] + sizeof(EDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + } + else + { + std::cout << "Instance: " << gemm_name + << ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 4fbdcc3583..3d00d5044b 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -25,6 +25,12 @@ if (CK_USE_XDL OR CK_USE_WMMA) add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk) endif() + add_gtest_executable(test_grouped_gemm_fixed_nk_bias test_grouped_gemm_fixed_nk_bias.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_fixed_nk_bias PRIVATE utility device_grouped_gemm_bias_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk_bias) + endif() + add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance) diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_bias.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk_bias.cpp new file mode 100644 index 0000000000..56d4051b79 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_bias.cpp @@ -0,0 +1,306 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
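A note on the verification strategy in the profiler above, which the test file below drives: the host reference computes E = A·B with a plain `PassThrough` epilogue, then folds the bias in exactly once per element by invoking `cde_element_op` directly — valid because a default-constructed `SplitKAdd` has `kbatch_id == 0` and `KBatch == 1`, so it takes the `Add` branch on the host. A condensed sketch of that pattern (hypothetical helper, simplified types):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper condensing the verification loops above: fold the bias
// into the reference output exactly once per element, mirroring what the
// device does on its last k-batch.
template <typename Op>
void apply_bias_once(std::vector<float>& e_ref, const std::vector<float>& d0, Op op)
{
    for(std::size_t i = 0; i < e_ref.size(); ++i)
    {
        op(e_ref[i], e_ref[i], d0[i]); // y = x0 + x1 while the op is in its default state
    }
}
```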
+// SPDX-License-Identifier: MIT
+
+#include "ck/ck.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
+#include "profiler/profile_grouped_gemm_fixed_nk_bias_impl.hpp"
+#include "gtest/gtest.h"
+
+#include <cstdlib>
+#include <iostream>
+#include <tuple>
+#include <vector>
+
+static ck::index_t param_mask     = 0xffffff;
+static ck::index_t instance_index = -1;
+
+using FP32 = float;
+using FP16 = ck::half_t;
+using BF16 = ck::bhalf_t;
+using Row  = ck::tensor_layout::gemm::RowMajor;
+using Col  = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+using SplitKAdd   = ck::tensor_operation::element_wise::SplitKAdd;
+
+using CDEElementOp = SplitKAdd;
+
+// clang-format off
+using KernelTypes = ::testing::Types<
+    std::tuple<FP16, FP16, ck::Tuple<FP16>, FP16, Row, Row, ck::Tuple<Row>, Row>,
+    std::tuple<FP16, FP16, ck::Tuple<FP16>, FP16, Row, Col, ck::Tuple<Row>, Row>
+>;
+// clang-format on
+
+template <typename Tuple>
+class TestGroupedGemmFixedNKBias : public testing::Test
+{
+    protected:
+    using ADataType   = std::tuple_element_t<0, Tuple>;
+    using BDataType   = std::tuple_element_t<1, Tuple>;
+    using DsDataType  = std::tuple_element_t<2, Tuple>;
+    using EDataType   = std::tuple_element_t<3, Tuple>;
+    using ALayout     = std::tuple_element_t<4, Tuple>;
+    using BLayout     = std::tuple_element_t<5, Tuple>;
+    using DsLayout    = std::tuple_element_t<6, Tuple>;
+    using ELayout     = std::tuple_element_t<7, Tuple>;
+    using AccDataType = FP32;
+
+    public:
+    static constexpr bool verify_     = true;
+    static constexpr int init_method_ = 1; // integer value initialization
+    static constexpr bool log_        = false;
+    static constexpr bool bench_      = false; // measure kernel performance
+    static constexpr int n_warmup_    = 0;
+    static constexpr int n_iter_      = 1;
+
+    std::vector<int> k_batches_;
+
+    bool IsSplitKSupported()
+    {
+        // gfx11 does not support split-K due to missing atomic add for fp16/bf16
+        // Technically, we could still use split-K for fp32, but we currently don't have
+        // instances for it so we disable it entirely
+        constexpr bool require_16bit_atomic_add =
+            std::is_same_v<EDataType, FP16> || std::is_same_v<EDataType, BF16>;
+        bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();
+
+        // CDE element operators are not supported in combination with split K
+        constexpr bool has_cde_element_operator =
+            !std::is_same_v<CDEElementOp, PassThrough> &&
+            !std::is_same_v<CDEElementOp, SplitKAdd>;
+
+        return !missing_atomic_add && !has_cde_element_operator;
+    }
+
+    void SetUp() override
+    {
+        if(!IsSplitKSupported())
+        {
+            k_batches_ = {1};
+        }
+        else
+        {
+            k_batches_ = {1, 2, 3, 4, 8};
+        }
+    }
+
+    private:
+    template <typename Layout>
+    void SetStrides(std::vector<int>& strides,
+                    const std::vector<int>& rows,
+                    const std::vector<int>& cols) const
+    {
+        if(std::is_same_v<Layout, Row>)
+        {
+            for(const auto c : cols)
+            {
+                strides.emplace_back(c);
+            }
+        }
+        else if(std::is_same_v<Layout, Col>)
+        {
+            for(const auto r : rows)
+            {
+                strides.emplace_back(r);
+            }
+        }
+    }
+
+    template <typename Layouts>
+    void SetTupleStrides(std::vector<int>& strides,
+                         const std::vector<int>& rows,
+                         const std::vector<int>& cols) const
+    {
+        if constexpr(Layouts::Size() > 0)
+        {
+            // As of now multi ABD implementation supports only tensors with matching layouts.
+ using Layout = ck::remove_cvref_t{}, Layouts>>; + SetStrides(strides, rows, cols); + } + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs = {}, + const std::vector& StrideBs = {}, + const std::vector& StrideDs = {}, + const std::vector& StrideEs = {}) + { + std::vector stride_as = StrideAs; + std::vector stride_bs = StrideBs; + std::vector stride_ds = StrideDs; + std::vector stride_es = StrideEs; + + if(stride_as.empty()) + { + SetStrides(stride_as, Ms, Ks); + } + if(stride_bs.empty()) + { + SetStrides(stride_bs, Ks, Ns); + } + if(stride_ds.empty()) + { + SetTupleStrides(stride_ds, Ms, Ns); + } + if(stride_es.empty()) + { + SetStrides(stride_es, Ms, Ns); + } + + std::vector k_batches; + for(size_t i = 0; i < k_batches_.size(); i++) + { + if(param_mask & (1 << i)) + { + k_batches.push_back(k_batches_[i]); + } + } + + RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_es, k_batches); + } + + void RunSingle(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideDs, + const std::vector& StrideEs, + const std::vector& kbatches) + { + bool pass = ck::profiler::profile_grouped_gemm_fixed_nk_bias_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideDs, + StrideEs, + kbatches, + n_warmup_, + n_iter_); + EXPECT_TRUE(pass); + } +}; + +TYPED_TEST_SUITE(TestGroupedGemmFixedNKBias, KernelTypes); + +TYPED_TEST(TestGroupedGemmFixedNKBias, TinyCases) +{ + const std::vector Ms{2, 1}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmFixedNKBias, SmallCases) +{ + const std::vector Ms{2, 1, 3, 4, 5}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmFixedNKBias, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmFixedNKBias, Regular) +{ + const std::vector Ms{64, 128, 256}; + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmFixedNKBias, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemmFixedNKBias, TestLargeKBatch) +{ + // In some cases Split K is not supported. 
Running this test would fail since no instance will
+    // be supported, so we skip the test
+    if(!this->IsSplitKSupported())
+    {
+        GTEST_SKIP() << "Split-K not supported for the current configuration (FP16/BF16 on "
+                        "GFX11, or using CDE element-wise operation)";
+    }
+
+    const std::vector<int> Ms{188, 210};
+    constexpr int N = 768;
+    constexpr int K = 4096;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+
+    this->k_batches_ = {32, 64};
+
+    this->Run(Ms, Ns, Ks);
+}
+
+int main(int argc, char** argv)
+{
+    testing::InitGoogleTest(&argc, argv);
+    if(argc == 1) {}
+    else if(argc == 3)
+    {
+        param_mask     = strtol(argv[1], nullptr, 0);
+        instance_index = atoi(argv[2]);
+    }
+    else
+    {
+        std::cout << "Usage of " << argv[0] << std::endl;
+        std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
+    }
+    return RUN_ALL_TESTS();
+}
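Finally, a usage note for the new test binary: `param_mask` (parsed with `strtol(..., 0)`, so hex such as `0x3` is accepted) bit-selects entries of the per-test k-batch list — bit i keeps `k_batches_[i]` — while `instance_index` is parsed but appears unused in the rest of this file. Example invocations, assuming the `test_grouped_gemm_fixed_nk_bias` target name from the CMake change above:

```
./test_grouped_gemm_fixed_nk_bias            # default: all k-batches enabled (mask 0xffffff)
./test_grouped_gemm_fixed_nk_bias 0x3 -1     # mask 0x3: keep only k_batches_[0] and k_batches_[1]
```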