diff --git a/Jenkinsfile b/Jenkinsfile
index b4adc5de95..15be3e540c 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -379,23 +379,23 @@ pipeline {
                 }
             }
         }
-        //stage("Client App")
-        //{
-        //    parallel
-        //    {
-        //        stage("Run Client App")
-        //        {
-        //            agent{ label rocmnode("gfx908")}
-        //            environment{
-        //                setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
-        //                execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """
-        //            }
-        //            steps{
-        //                buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
-        //            }
-        //        }
-        //    }
-        //}
+        stage("Client App")
+        {
+            parallel
+            {
+                stage("Run Client App")
+                {
+                    agent{ label rocmnode("gfx908")}
+                    environment{
+                        setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
+                        execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc .. && make -j """
+                    }
+                    steps{
+                        buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
+                    }
+                }
+            }
+        }
         stage("Performance Tests")
         {
             parallel
diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
new file mode 100644
index 0000000000..1064abc8fa
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp)
+target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations)
diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
new file mode 100644
index 0000000000..bdd6e05029
--- /dev/null
+++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp
@@ -0,0 +1,237 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
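+// Client example for fused GEMM + Add + Add + FastGelu, i.e.
+// E = FastGelu(A * B + D0 + D1): it enumerates the pre-built device-op
+// instances, times every instance that supports the problem, and then
+// re-runs the fastest one without timing.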
+
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp"
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
+using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+
+using AElementOp   = PassThrough;
+using BElementOp   = PassThrough;
+using CDEElementOp = AddAddFastGelu;
+
+using ADataType   = F16;
+using BDataType   = F16;
+using AccDataType = F32;
+using D0DataType  = F16;
+using D1DataType  = F16;
+using EDataType   = F16;
+
+using ALayout  = Row;
+using BLayout  = Col;
+using D0Layout = Row;
+using D1Layout = Row;
+using ELayout  = Row;
+
+struct SimpleDeviceMem
+{
+    SimpleDeviceMem() = delete;
+
+    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
+    {
+        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
+    }
+
+    void* GetDeviceBuffer() { return p_mem_; }
+
+    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
+
+    void* p_mem_;
+};
+
+int main(int argc, char* argv[])
+{
+    // GEMM shape
+    ck::index_t M = 3840;
+    ck::index_t N = 4096;
+    ck::index_t K = 4096;
+
+    ck::index_t StrideA  = 4096;
+    ck::index_t StrideB  = 4096;
+    ck::index_t StrideD0 = 0;
+    ck::index_t StrideD1 = 4096;
+    ck::index_t StrideE  = 4096;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 9)
+    {
+        M = std::stoi(argv[1]);
+        N = std::stoi(argv[2]);
+        K = std::stoi(argv[3]);
+
+        StrideA  = std::stoi(argv[4]);
+        StrideB  = std::stoi(argv[5]);
+        StrideD0 = std::stoi(argv[6]);
+        StrideD1 = std::stoi(argv[7]);
+        StrideE  = std::stoi(argv[8]);
+    }
+    else
+    {
+        printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
+        exit(0);
+    }
+
+    auto f_matrix_space_size =
+        [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
+            using Layout = decltype(layout);
+
+            if(std::is_same<Layout, Row>::value)
+            {
+                return (nRow - 1) * stride + nCol;
+            }
+            else
+            {
+                return (nCol - 1) * stride + nRow;
+            }
+        };
+
+    SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
+    SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
+    SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) *
+                                      f_matrix_space_size(M, N, StrideD0, D0Layout{}));
+    SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) *
+                                      f_matrix_space_size(M, N, StrideD1, D1Layout{}));
+    SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
+
+    // add device op instances
+    const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance::
+        get_device_gemm_add_add_fastgelu_instances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    const auto a_element_op   = AElementOp{};
+    const auto b_element_op   = BElementOp{};
+    const auto cde_element_op = CDEElementOp{};
+
+    std::string best_op_name;
+    bool found            = false;
+    int best_op_id        = -1;
+    float best_ave_time   = 0;
+    float best_tflops     = 0;
+    float best_gb_per_sec = 0;
+
+    // profile device operation instances
+    std::cout << "Run all instances and do timing" << std::endl;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr =
+            op_ptr->MakeArgumentPointer(
+                a_device_buf.GetDeviceBuffer(),
+                b_device_buf.GetDeviceBuffer(),
+                std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
+                                           d1_m_n_device_buf.GetDeviceBuffer()},
+                e_device_buf.GetDeviceBuffer(),
+                M,
+                N,
+                K,
+                StrideA,
+                StrideB,
+                std::array<ck::index_t, 2>{StrideD0, StrideD1},
+                StrideE,
+                a_element_op,
+                b_element_op,
+                cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+
+            std::size_t num_btype =
+                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    // run the best instance
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            a_device_buf.GetDeviceBuffer(),
+            b_device_buf.GetDeviceBuffer(),
+            std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
+                                       d1_m_n_device_buf.GetDeviceBuffer()},
+            e_device_buf.GetDeviceBuffer(),
+            M,
+            N,
+            K,
+            StrideA,
+            StrideB,
+            std::array<ck::index_t, 2>{StrideD0, StrideD1},
+            StrideE,
+            a_element_op,
+            b_element_op,
+            cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+    }
+
+    return 0;
+}
diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt
new file mode 100644
index 0000000000..192959662a
--- /dev/null
+++ b/client_example/CMakeLists.txt
@@ -0,0 +1,9 @@
+cmake_minimum_required(VERSION 3.15)
+project(ck_app)
+add_compile_options(-std=c++17)
+
+find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
+find_package(hip REQUIRED PATHS /opt/rocm)
+message(STATUS "Build with HIP ${hip_VERSION}")
+
+add_subdirectory(02_gemm_add_add_fastgelu)
diff --git a/client_example/README.md b/client_example/README.md
new file mode 100644
index 0000000000..dc6b9c48fc
--- /dev/null
+++ b/client_example/README.md
@@ -0,0 +1,32 @@
+## Introduction
+A client application links against the installed CK library, so the CK library must be installed before any client application is built.
+
+## Docker script
+```bash
+docker run \
+-it \
+--privileged \
+--group-add sudo \
+-w /root/workspace \
+-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
+rocm/tensorflow:rocm5.1-tf2.6-dev \
+/bin/bash
+```
+
+## Build
+```bash
+mkdir -p client_example/build
+cd client_example/build
+```
+
+```bash
+cmake \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH=/opt/rocm \
+..
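+# Note: CMAKE_PREFIX_PATH must contain the prefix CK was installed under;
+# /opt/rocm is only the default location, adjust it for a custom install.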
+```
+
+### Build client example
+```bash
+ make -j
+```
diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp
index 19cb07e515..0575c0bd9e 100644
--- a/example/01_gemm/gemm_xdl_bf16.cpp
+++ b/example/01_gemm/gemm_xdl_bf16.cpp
@@ -84,8 +84,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
         8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
 // clang-format on
 
-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm;
+using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        AccDataType,
+                                                                        AElementOp,
+                                                                        BElementOp,
+                                                                        CElementOp>;
 
 int main(int argc, char* argv[])
 {
@@ -216,24 +221,17 @@ int main(int argc, char* argv[])
 
     if(do_verification)
     {
-        Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
-        Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-        Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-        Tensor<float> c_m_n_device_f32_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-
-        bf16_to_f32_(a_m_k, a_f32_m_k);
-        bf16_to_f32_(b_k_n, b_f32_k_n);
-        bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
+        Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
 
         auto ref_gemm    = ReferenceGemmInstance{};
         auto ref_invoker = ref_gemm.MakeInvoker();
 
         auto ref_argument = ref_gemm.MakeArgument(
-            a_f32_m_k, b_f32_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
 
         ref_invoker.Run(ref_argument);
 
-        return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
+        return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
     }
 
     return 0;
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
new file mode 100644
index 0000000000..4fc953b3a6
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
@@ -0,0 +1,45 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <iostream>
+#include <memory>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceBatchedGemm : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                              const void* p_b,
+                                                              void* p_c,
+                                                              ck::index_t M,
+                                                              ck::index_t N,
+                                                              ck::index_t K,
+                                                              ck::index_t StrideA,
+                                                              ck::index_t StrideB,
+                                                              ck::index_t StrideC,
+                                                              AElementwiseOperation a_element_op,
+                                                              BElementwiseOperation b_element_op,
+                                                              CElementwiseOperation c_element_op,
+                                                              ck::index_t Batch) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceBatchedGemmPtr = std::unique_ptr<
+    DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp
new file mode 100644
index 0000000000..036eb3df4b
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp
@@ -0,0 +1,54 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
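+// Interface for batched GEMM fused with auxiliary reductions: relative to the
+// plain GEMM contract it adds a trailing Batch count, a tuple of reduction
+// outputs (p_dxs) and their element-wise operations (dxs_in/dxs_out).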
+
+#pragma once
+#include <memory>
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename DxsInElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation>
+struct DeviceBatchedGemmReduce : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        void* p_c,
+                        void* p_dxs,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        DxsInElementwiseOperation dxs_in_element_op,
+                        DxsReduceAccElementwiseOperation dxs_out_element_op,
+                        ck::index_t Batch) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename DxsInElementwiseOperation,
+          typename DxsReduceAccElementwiseOperation>
+using DeviceBatchedGemmReducePtr =
+    std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
+                                            BElementwiseOperation,
+                                            CElementwiseOperation,
+                                            DxsInElementwiseOperation,
+                                            DxsReduceAccElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
index c24ec54e56..5ae610fc8c 100644
--- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -111,7 +111,7 @@ __global__ void
         ignore = d_grid_desc_mblock_mperblock;
         ignore = compute_base_ptr_of_batch_;
         ignore = block_2_ctile_map;
-#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
+#endif
 }
 
 // Note: inter-wave loop scheduler is rolled out to c-shuffle version first.
Becuase non c-shuffle @@ -169,11 +169,11 @@ template struct DeviceBatchedGemmReduce_Xdl_CShuffle - : public DeviceGemmReduce + : public DeviceBatchedGemmReduce { using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; @@ -594,12 +594,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle CElementwiseOperation c_element_op, DxsInElementwiseOperation dxs_in_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op, - index_t BatchCount) + index_t Batch) : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, p_ds_grid_{p_ds_grid}, - BatchCount_(BatchCount), + Batch_(Batch), a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, @@ -637,7 +637,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle const BDataType* p_b_grid_; CDataType* p_c_grid_; DPtrsGlobal p_ds_grid_; - index_t BatchCount_; + index_t Batch_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; @@ -663,7 +663,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle { #if 0 { - std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl; + std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; std::cout << "arg.a_grid_desc_ak0_m_ak1_{" << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " @@ -692,7 +692,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle } const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -728,7 +728,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle arg.p_b_grid_, arg.p_c_grid_, arg.p_ds_grid_, - arg.BatchCount_, + arg.Batch_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -771,7 +771,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle arg.p_b_grid_, arg.p_c_grid_, arg.p_ds_grid_, - arg.BatchCount_, + arg.Batch_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -839,7 +839,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle CElementwiseOperation c_element_op, DxsInElementwiseOperation dxs_in_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op, - index_t BatchCount) + index_t Batch) { return Argument{p_a, p_b, @@ -856,7 +856,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle c_element_op, dxs_in_element_op, dxs_out_element_op, - BatchCount}; + Batch}; } static auto MakeInvoker() { return Invoker{}; } @@ -878,7 +878,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle CElementwiseOperation c_element_op, DxsInElementwiseOperation dxs_in_element_op, DxsReduceAccElementwiseOperation dxs_out_element_op, - index_t BatchCount) override + index_t Batch) override { DPtrsGlobal dxs_tuple = *(static_cast(p_dxs)); return std::make_unique(static_cast(p_a), @@ -896,7 +896,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle c_element_op, dxs_in_element_op, dxs_out_element_op, - BatchCount); + Batch); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp index 0b5ade2544..c63dfd2c53 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp @@ -10,7 +10,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" 
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" #include "ck/device_utility/device_prop.hpp" @@ -152,7 +152,7 @@ template struct DeviceBatchedGemmXdl - : public DeviceGemm + : public DeviceBatchedGemm { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -339,11 +339,11 @@ struct DeviceBatchedGemmXdl AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, - index_t BatchCount) + index_t Batch) : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, - BatchCount_(BatchCount), + Batch_(Batch), a_grid_desc_k0_m_k1_{ DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, b_grid_desc_k0_n_k1_{ @@ -376,7 +376,7 @@ struct DeviceBatchedGemmXdl const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - index_t BatchCount_; + index_t Batch_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; @@ -420,7 +420,7 @@ struct DeviceBatchedGemmXdl } const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -451,7 +451,7 @@ struct DeviceBatchedGemmXdl arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, - arg.BatchCount_, + arg.Batch_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, @@ -485,7 +485,7 @@ struct DeviceBatchedGemmXdl arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, - arg.BatchCount_, + arg.Batch_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, @@ -539,7 +539,7 @@ struct DeviceBatchedGemmXdl AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, - index_t BatchCount) + index_t Batch) { return Argument{p_a, p_b, @@ -555,7 +555,7 @@ struct DeviceBatchedGemmXdl a_element_op, b_element_op, c_element_op, - BatchCount}; + Batch}; } static auto MakeInvoker() { return Invoker{}; } @@ -573,7 +573,7 @@ struct DeviceBatchedGemmXdl AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, - index_t BatchCount) override + index_t Batch) override { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -589,7 +589,7 @@ struct DeviceBatchedGemmXdl a_element_op, b_element_op, c_element_op, - BatchCount); + Batch); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp new file mode 100644 index 0000000000..5950d8f8dd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+
+#pragma once
+#include <iostream>
+#include <memory>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGemmSplitK : public BaseOperator
+{
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
+                                                              const void* p_b,
+                                                              void* p_c,
+                                                              ck::index_t M,
+                                                              ck::index_t N,
+                                                              ck::index_t K,
+                                                              ck::index_t StrideA,
+                                                              ck::index_t StrideB,
+                                                              ck::index_t StrideC,
+                                                              AElementwiseOperation a_element_op,
+                                                              BElementwiseOperation b_element_op,
+                                                              CElementwiseOperation c_element_op,
+                                                              ck::index_t KBatch) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+template <typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation>
+using DeviceGemmSplitKPtr = std::unique_ptr<
+    DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
index 3be6283e48..9d24a4932d 100644
--- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -57,7 +57,7 @@ template
 struct DeviceGemmXdlSplitK
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+    : public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
index 1baaae4659..f484de324a 100644
--- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
@@ -10,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
 #include "ck/device_utility/device_prop.hpp"
@@ -59,7 +59,7 @@ template
 struct DeviceGemmXdlSplitKCShuffle
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+    : public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -420,21 +420,22 @@ struct DeviceGemmXdlSplitKCShuffle
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
                     sizeof(CDataType)));
 
-                launch_and_time_kernel(stream_config,
-                                       kernel,
-                                       dim3(grid_size),
-                                       dim3(BlockSize),
-                                       0,
-                                       arg.p_a_grid_,
-                                       arg.p_b_grid_,
-                                       arg.p_c_grid_,
-                                       arg.a_grid_desc_kbatch_k0_m_k1_,
-                                       arg.b_grid_desc_kbatch_k0_n_k1_,
-                                       arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                       arg.a_element_op_,
-                                       arg.b_element_op_,
-                                       arg.c_element_op_,
-                                       arg.block_2_ctile_map_);
+                ave_time =
+                    launch_and_time_kernel(stream_config,
+                                           kernel,
+                                           dim3(grid_size),
+                                           dim3(BlockSize),
+                                           0,
+                                           arg.p_a_grid_,
+                                           arg.p_b_grid_,
+                                           arg.p_c_grid_,
+                                           arg.a_grid_desc_kbatch_k0_m_k1_,
+                                           arg.b_grid_desc_kbatch_k0_n_k1_,
+                                           arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_,
+                                           arg.block_2_ctile_map_);
             };
 
             if(has_main_k0_block_loop)
diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt
index aa18026932..a92fae9e26 100644
--- a/library/CMakeLists.txt
+++ b/library/CMakeLists.txt
@@ -1,3 +1,3 @@
-add_subdirectory(src/host_tensor)
 add_subdirectory(src/tensor_operation_instance/gpu)
+add_subdirectory(src/host_tensor)
 add_subdirectory(src/utility)
diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp
index ac1e7dafd7..87e98f6e54 100644
--- a/library/include/ck/library/host_tensor/host_tensor.hpp
+++ b/library/include/ck/library/host_tensor/host_tensor.hpp
@@ -364,13 +364,8 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<std::size_t>& lens,
 {
 }
 
-void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
-
 #if 1
 // FIXME: remove
-void bf16_to_f32_(const Tensor<ck::bhalf_t>& src, Tensor<float>& dst);
-#endif
-
 template <typename T>
 float check_error(const Tensor<T>& ref, const Tensor<T>& result)
 {
@@ -416,3 +411,4 @@
 
     return linf_error;
 }
+#endif
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
index 680ced1629..06e74a9e9a 100644
--- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
@@ -62,20 +62,20 @@ struct ReferenceBatchedGemm : public device::BaseOperator
 
             for(int k = 0; k < K; ++k)
             {
-                float v_a;
-                float v_b;
+                ADataType v_a;
+                BDataType v_b;
 
-                arg.a_element_op_(v_a, static_cast<float>(arg.a_g_m_k_(g, m, k)));
-                arg.b_element_op_(v_b, static_cast<float>(arg.b_g_k_n_(g, k, n)));
+                arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k));
+                arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n));
 
-                v_acc += v_a * v_b;
+                v_acc += ck::type_convert<float>(v_a) * ck::type_convert<float>(v_b);
             }
 
             float v_c;
 
             arg.c_element_op_(v_c, v_acc);
 
-            arg.c_g_m_n_(g, m, n) = v_c;
+            arg.c_g_m_n_(g, m, n) = ck::type_convert<CDataType>(v_c);
         };
 
         make_ParallelTensorFunctor(f_gmk_gkn_gmn,
diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
index a1047d51f8..e3dd4de5df 100644
--- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
+++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
@@ -63,20 +63,21 @@ struct ReferenceGemm : public device::BaseOperator
 
             for(int k = 0; k < K; ++k)
             {
-                AccDataType v_a;
-                AccDataType v_b;
+                ADataType v_a;
+                BDataType v_b;
 
-                arg.a_element_op_(v_a, static_cast<AccDataType>(arg.a_m_k_(m, k)));
-                arg.b_element_op_(v_b, static_cast<AccDataType>(arg.b_k_n_(k, n)));
+                arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+                arg.b_element_op_(v_b, arg.b_k_n_(k, n));
 
-                v_acc += v_a * v_b;
+                v_acc +=
+                    ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }
 
             AccDataType v_c;
 
             arg.c_element_op_(v_c, v_acc);
 
-            arg.c_m_n_(m, n) = v_c;
+            arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
         };
 
         make_ParallelTensorFunctor(
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp
new file mode 100644
index 0000000000..6379ac26cd
--- /dev/null
+++
b/library/include/ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp @@ -0,0 +1,203 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_batched_gemm_instance { + +using DeviceBatchedGemmNoOpPtr = ck::tensor_operation::device::DeviceBatchedGemmPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( + std::vector&); +void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( + std::vector&); + +template +auto get_device_batched_gemm_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + 
ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_batched_gemm_instance:: + add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs); + } + } + + return op_ptrs; +} + +} // namespace device_batched_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp new file mode 100644 index 0000000000..6aa33e4d20 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
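+// Instance factory for fused GEMM + Add + Add + FastGelu.
+// get_device_gemm_add_add_fastgelu_instances() picks, at compile time, the
+// instance set matching the data types and layouts of A, B, D0/D1 and E.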
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr< + 2, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +template +auto get_device_gemm_add_add_fastgelu_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + op_ptrs); + } + } + + return op_ptrs; +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp new file mode 100644 index 0000000000..665b63c942 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
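+// Instance factory for plain GEMM: dispatches on data type (f32, f16, bf16,
+// int8) and on the row-/column-major layout combination of A, B and C,
+// aggregating the xdl, dl and c_shuffle instance families declared below.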
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); + +template +auto get_device_gemm_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && 
+ is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + + 
ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp new file mode 100644 index 0000000000..c1fa54ad2a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
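+// Instance factory for split-K GEMM: f32 and f16 instance sets, one per
+// row-/column-major layout combination of A, B and C.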
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( + std::vector&); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +template +auto get_device_gemm_splitk_instances() +{ + std::vector op_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/host_tensor/host_tensor.cpp b/library/src/host_tensor/host_tensor.cpp index 94783b73c9..dc9f5699dc 100644 --- a/library/src/host_tensor/host_tensor.cpp +++ 
b/library/src/host_tensor/host_tensor.cpp @@ -54,25 +54,3 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) return os; } - -void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os) -{ - os << "dim " << desc.GetNumOfDimension() << ", "; - - os << "lengths {"; - LogRange(os, desc.GetLengths(), ", "); - os << "}, "; - - os << "strides {"; - LogRange(os, desc.GetStrides(), ", "); - os << "}" << std::endl; -} - -#if 1 -// FIXME: remove -void bf16_to_f32_(const Tensor& src, Tensor& dst) -{ - for(std::size_t i = 0; i < src.mData.size(); ++i) - dst.mData[i] = ck::type_convert(src.mData[i]); -} -#endif diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 73236b856b..6366a4d6df 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -6,43 +6,45 @@ function(add_instance_library INSTANCE_NAME) endfunction(add_instance_library INSTANCE_NAME) add_subdirectory(gemm) +add_subdirectory(gemm_splitk) add_subdirectory(gemm_bias2d) add_subdirectory(gemm_bias_relu) add_subdirectory(gemm_bias_relu_add) add_subdirectory(gemm_reduce) add_subdirectory(gemm_bias_add_reduce) +add_subdirectory(gemm_add_add_fastgelu) add_subdirectory(batched_gemm) +add_subdirectory(batched_gemm_reduce) +add_subdirectory(grouped_gemm) add_subdirectory(conv1d_fwd) add_subdirectory(conv2d_fwd) add_subdirectory(conv3d_fwd) add_subdirectory(conv2d_fwd_bias_relu) add_subdirectory(conv2d_fwd_bias_relu_add) add_subdirectory(conv2d_bwd_data) -add_subdirectory(reduce) add_subdirectory(convnd_bwd_data) -add_subdirectory(grouped_gemm) add_subdirectory(conv2d_bwd_weight) -add_subdirectory(batched_gemm_reduce) -add_subdirectory(gemm_add_add_fastgelu) +add_subdirectory(reduce) add_library(device_operations STATIC - $ - $ - $ - $ - $ - $ $ + $ $ $ $ - $ - $ - $ - $ - $ - $ $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ ) add_library(composablekernels::device_operations ALIAS device_operations) @@ -67,8 +69,8 @@ target_include_directories(device_operations PUBLIC $ $ $ - $ $ + $ $ ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp index d9422b2f6d..6a262b7929 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp index d4a2b724fe..15549d8444 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp @@ -44,7 +44,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp index 9e3f8e68c5..ad9c8eff40 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp @@ -48,7 +48,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp index f16c724c71..a5afc76586 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp @@ -49,7 +49,7 @@ using device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp index 057a3f7508..666c64e016 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp @@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp index d35bd6c350..ad97d3530e 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp @@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp index 81b2d23ba6..593903c718 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -53,7 +53,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp index 3144b4716e..0220919f8e 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp index 5a323e2928..74e36e9dd2 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp @@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp index f3bac97d93..5873433e2d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp @@ -44,7 +44,7 @@ 
using device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp index 90ec4bc4d0..14b994e1f6 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp @@ -44,7 +44,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp index 7c8efa0aef..2c656e7ebb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp @@ -49,7 +49,7 @@ using device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp index de91f25ebe..feef3b48ce 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp @@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp index 0dd0549dd1..df24ae135d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp @@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances = std::tuple< >; void 
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp index 4b994cc8b0..fb769fc1bb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp @@ -59,7 +59,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp index ccb3bbd447..389f4225ef 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp @@ -51,7 +51,7 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple< >; void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp index 0ed06bc690..82e230f301 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( - std::vector< - DeviceGemmReducePtr>& - instances) + std::vector>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp index 5be051225a..16826fdf22 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +++ 
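Every *_instance.cpp hunk above ends in the same registration idiom: a std::tuple listing concrete kernel instantiations is expanded into the caller's vector of abstract device-op pointers. A minimal sketch of that helper, with the fold-expression body being an assumption (the real definition lives in the instance-factory headers):

#include <memory>
#include <tuple>
#include <vector>

// OpPtr is a std::unique_ptr to the abstract device-op interface; each
// ConcreteOp in the tuple must derive from that interface. One object per
// tuple element type is default-constructed and appended.
template <typename OpPtr, typename... ConcreteOps>
void add_device_operation_instances(std::vector<OpPtr>& instances,
                                    const std::tuple<ConcreteOps...>&)
{
    (instances.push_back(std::make_unique<ConcreteOps>()), ...);
}

This is why the signature changes in these hunks matter: the element type of the instances vector is the abstract interface, so moving batched GEMM to its own interface type requires touching every registration function.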
b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( - std::vector< - DeviceGemmReducePtr>& - instances) + std::vector>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp index 2cc1c85ece..8f2bf3694f 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -67,9 +67,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( - std::vector< - DeviceGemmReducePtr>& - instances) + std::vector>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp index f457d5b38f..c2eb10a195 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -64,9 +64,11 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( - std::vector< - DeviceGemmReducePtr>& - instances) + std::vector>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 8de1920bb3..ce66b56a3e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -28,14 +28,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; - device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; 
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt new file mode 100644 index 0000000000..3700ddf19d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -0,0 +1,15 @@ +set(DEVICE_GEMM_SPLITK_INSTANCE_SOURCE + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; +) + +add_library(device_gemm_splitk_instance OBJECT ${DEVICE_GEMM_SPLITK_INSTANCE_SOURCE}) + +target_compile_features(device_gemm_splitk_instance PUBLIC) +set_target_properties(device_gemm_splitk_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp similarity index 98% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp index b1b6636869..311b8c088e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp @@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp similarity index 98% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp index f3bd27a24f..657135e295 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp @@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp similarity index 98% rename from 
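For context on the new gemm_splitk/CMakeLists.txt above: an OBJECT library compiles its sources without producing an archive, and the parent device_operations STATIC library then pulls the compiled objects in through a $<TARGET_OBJECTS:device_gemm_splitk_instance> generator expression in its source list, the same way the other device_*_instance targets are consumed. POSITION_INDEPENDENT_CODE is set on the object target so the objects remain usable if the parent library is ever built shared.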
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 9032b57a3a..10229534a9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp similarity index 99% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 71a0e4d38b..31bf3233cd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -83,7 +83,7 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< // >; void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp similarity index 98% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp index ac5435b8f3..f3a26d6de8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp similarity index 98% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp index 83d267edde..381fc1ced5 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -46,7 +46,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp similarity index 99% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp index e4e89c1ddc..47b3f2ebd0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -51,7 +51,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp similarity index 99% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp index d324a67eb7..d532fe1e77 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -51,7 +51,7 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< >; void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector>& instances) { add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{}); diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index b48f28a23a..b5d341095b 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -6,6 +6,7 @@ include_directories(BEFORE set(PROFILER_SOURCE src/profiler.cpp src/profile_gemm.cpp + src/profile_gemm_splitk.cpp src/profile_gemm_bias_2d.cpp src/profile_gemm_bias_relu.cpp src/profile_gemm_bias_relu_add.cpp @@ -27,21 +28,22 @@ add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE conv_util) -target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance) target_link_libraries(ckProfiler PRIVATE 
device_gemm_bias2d_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance) target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) +target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) -target_link_libraries(ckProfiler PRIVATE device_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance) -target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) -target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance) +target_link_libraries(ckProfiler PRIVATE device_reduce_instance) diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index 40dd693d14..21bb1d86a9 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -7,56 +7,17 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/device_batched_gemm_instance.hpp" + #include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/conv_util.hpp" #include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_batched_gemm_instance { - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(std::vector&); -void 
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances( - std::vector&); -void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( - std::vector&); - -} // namespace device_batched_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - namespace ck { namespace profiler { @@ -103,27 +64,22 @@ bool profile_batched_gemm_impl(int do_verification, f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); Tensor c_g_m_n_device_result( f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - std::unique_ptr> c_f32_g_m_n_host_result = nullptr; - std::unique_ptr> c_f32_g_m_n_device_result = nullptr; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; - std::size_t num_thread = 1; switch(init_method) { case 0: break; case 1: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - // set zero to c_device_buf - c_g_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -135,56 +91,21 @@ bool profile_batched_gemm_impl(int do_verification, if(do_verification) { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - Tensor a_f32_g_m_k( - f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); - Tensor b_f32_g_k_n( - f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); - c_f32_g_m_n_host_result = std::make_unique>( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); - c_f32_g_m_n_device_result = std::make_unique>( - f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; - bf16_to_f32_(a_g_m_k, a_f32_g_m_k); - bf16_to_f32_(b_g_k_n, b_f32_g_k_n); + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); - using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: - ReferenceBatchedGemm; + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); - auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; - auto ref_invoker = ref_batched_gemm.MakeInvoker(); - - auto ref_argument = ref_batched_gemm.MakeArgument(a_f32_g_m_k, - b_f32_g_k_n, - *c_f32_g_m_n_host_result, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - } - else - { - - using ReferenceBatchedGemmInstance = - ck::tensor_operation::host::ReferenceBatchedGemm; - 
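The extern declarations deleted above, together with the long is_same dispatch chains deleted further below, are replaced by a single factory call. A minimal sketch of what such a factory looks like; the template parameters and the pointer alias name are assumptions (the hunk shows the call as get_device_batched_gemm_instances() with its arguments stripped):

#include <type_traits>
#include <vector>

template <typename ADataType, typename BDataType, typename CDataType,
          typename ALayout, typename BLayout, typename CLayout>
auto get_device_batched_gemm_instances()
{
    std::vector<DeviceGemmNoOpPtr> op_ptrs; // alias name assumed

    if constexpr(std::is_same_v<ADataType, ck::half_t> &&
                 std::is_same_v<ALayout, Row> && std::is_same_v<BLayout, Row>)
    {
        add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(op_ptrs);
    }
    // ... one branch per data-type / layout combination, mirroring the
    // dispatch chains this patch removes from the profiler header ...

    return op_ptrs;
}

The payoff is that the compile-time type dispatch now lives once in the instance library instead of being copy-pasted into every profiler implementation header.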
- auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; - auto ref_invoker = ref_batched_gemm.MakeInvoker(); - - auto ref_argument = ref_batched_gemm.MakeArgument( - a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - } + ref_invoker.Run(ref_argument); } DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); @@ -195,172 +116,51 @@ bool profile_batched_gemm_impl(int do_verification, b_device_buf.ToDevice(b_g_k_n.mData.data()); c_device_buf.ToDevice(c_g_m_n_device_result.mData.data()); - // add device GEMM instances - std::vector - gemm_ptrs; + // add device op instances + const auto op_ptrs = ck::tensor_operation::device::device_batched_gemm_instance:: + get_device_batched_gemm_instances(); - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - 
{ - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_batched_gemm_instance:: - add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(gemm_ptrs); - } - } - - if(gemm_ptrs.size() <= 0) + if(op_ptrs.size() <= 0) { throw std::runtime_error("wrong! no device GEMM instance found"); } - std::string best_gemm_name; + std::string best_op_name; float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; - // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) + // profile device op instances + for(auto& op_ptr : op_ptrs) { auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - BatchCount); + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + BatchCount); - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { - std::string gemm_name = gemm_ptr->GetTypeString(); + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -376,11 +176,11 @@ bool profile_batched_gemm_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm_name << std::endl; + << " GB/s, " << op_name << std::endl; if(tflops > best_tflops) { - best_gemm_name = gemm_name; + best_op_name = op_name; best_tflops = tflops; best_ave_time = ave_time; best_gb_per_sec = gb_per_sec; @@ -390,20 +190,8 @@ bool profile_batched_gemm_impl(int do_verification, { c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - - bf16_to_f32_(c_g_m_n_device_result, *c_f32_g_m_n_device_result); - float 
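How the "Perf:" line printed above is derived: ave_time comes back from the invoker in milliseconds, and the flop/byte counts follow the usual GEMM convention (the BatchCount factor for the batched case is reconstructed here, not visible in the hunk):

// total work: one multiply + one add per (m, n, k) element, per batch
std::size_t flop = std::size_t(2) * BatchCount * M * N * K;

// total traffic: read A and B once, write C once
std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
                        sizeof(BDataType) * BatchCount * K * N +
                        sizeof(CDataType) * BatchCount * M * N;

// ave_time is in ms, so flop / (ave_time * 1e-3 s) / 1e12 == flop / 1e9 / ave_time
float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;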
err = check_error(*c_f32_g_m_n_host_result, *c_f32_g_m_n_device_result); - pass = pass && (err < 1E-6); - } - else - { - float err = check_error(c_g_m_n_host_result, c_g_m_n_device_result); - pass = pass && (err < 1E-6); - } + pass = pass & + ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData); if(do_log) { @@ -419,13 +207,12 @@ bool profile_batched_gemm_impl(int do_verification, } else { - std::cout << "this device GEMM instance does not support this GEMM problem" - << std::endl; + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; } } std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; return pass; } diff --git a/profiler/include/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profile_batched_gemm_reduce_impl.hpp index e3c5a331fa..5b9557f7be 100644 --- a/profiler/include/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profile_batched_gemm_reduce_impl.hpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/utility/reduction_operator.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" @@ -29,7 +29,7 @@ using Square = ck::tensor_operation::element_wise::UnarySquare; using DInElementOps = ck::Tuple; using DOutElementOps = ck::Tuple; -using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< +using DeviceBatchedGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceBatchedGemmReducePtr< ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, @@ -37,16 +37,16 @@ using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePt DOutElementOps>; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( - std::vector&); + std::vector&); void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( - std::vector&); + std::vector&); void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( - std::vector&); + std::vector&); void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( - std::vector&); + std::vector&); } // namespace device_gemm_instance } // namespace device @@ -204,7 +204,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, b_device_buf.ToDevice(b_g_k_n.mData.data()); // add device GEMM instances - std::vector + std::vector gemm_ptrs; if constexpr(is_same::value && is_same::value && diff --git a/profiler/include/profile_convnd_fwd.hpp b/profiler/include/profile_convnd_fwd.hpp deleted file mode 100644 index a0cbd3de28..0000000000 --- a/profiler/include/profile_convnd_fwd.hpp +++ /dev/null @@ -1,12 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
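The verification hunk above swaps the old check_error-plus-threshold pattern for ck::utils::check_err, which returns a bool directly and picks type-aware tolerances (which is also why the bf16-to-f32 conversion path could be deleted). A minimal stand-in for arithmetic element types; the tolerance defaults are assumptions, and the real helper in ck/library/utility/check_err.hpp additionally logs the first mismatches:

#include <cmath>
#include <cstddef>
#include <vector>

template <typename T>
bool check_err_sketch(const std::vector<T>& result, const std::vector<T>& ref,
                      double rtol = 1e-5, double atol = 3e-6)
{
    if(result.size() != ref.size())
        return false;

    for(std::size_t i = 0; i < result.size(); ++i)
    {
        const double out    = static_cast<double>(result[i]);
        const double expect = static_cast<double>(ref[i]);
        // mixed absolute/relative tolerance test
        if(std::abs(out - expect) > atol + rtol * std::abs(expect))
            return false;
    }
    return true;
}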
- -#pragma once - -namespace ck { -namespace profiler { - -int profile_convnd_fwd(int argc, char* argv[]); - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp index a32db463b1..a39d55acae 100644 --- a/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp +++ b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp @@ -9,6 +9,9 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/device_gemm_add_add_fastgelu_instance.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/host_tensor.hpp" @@ -16,31 +19,6 @@ #include "ck/library/host_tensor/host_conv.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr< - 2, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::AddAddFastGelu>; - -void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( - std::vector&); -void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( - std::vector&); -void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( - std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - namespace ck { namespace profiler { @@ -55,18 +33,18 @@ template -int profile_gemm_add_add_fastgelu_impl(int do_verification, - int init_method, - bool /*do_log*/, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideD0, - int StrideD1, - int StrideE) +bool profile_gemm_add_add_fastgelu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE) { auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -122,48 +100,21 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto cde_element_op = CDEElementOp{}; - // add device GEMM instances - std::vector - device_op_ptrs; + // add device op instances + const auto op_ptrs = ck::tensor_operation::device::device_gemm_instance:: + get_device_gemm_add_add_fastgelu_instances(); - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(is_same_v && - is_same_v && - is_same_v) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( - device_op_ptrs); - } - else if constexpr(is_same_v && - is_same_v && - is_same_v) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( - device_op_ptrs); - } - else if constexpr(is_same_v && - is_same_v && - is_same_v) - { - ck::tensor_operation::device::device_gemm_instance:: - 
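In the surrounding profile_gemm_add_add_fastgelu_impl hunk, the DeviceGemmMultipleD argument packs the two auxiliary D tensors as arrays of pointers and strides; the fused epilogue computes E = FastGelu(A * B + D0 + D1). A sketch of the call shape, with the array element types being assumptions (the hunk shows them stripped):

auto argument_ptr = op_ptr->MakeArgumentPointer(
    a_device_buf.GetDeviceBuffer(),
    b_device_buf.GetDeviceBuffer(),
    std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
                               d1_m_n_device_buf.GetDeviceBuffer()},
    e_device_buf.GetDeviceBuffer(), // untyped now; the static_cast is dropped
    M, N, K,
    StrideA, StrideB,
    std::array<ck::index_t, 2>{StrideD0, StrideD1}, // stride types assumed
    StrideE,
    a_element_op, b_element_op, cde_element_op);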
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( - device_op_ptrs); - } - else if constexpr(is_same_v && - is_same_v && - is_same_v) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( - device_op_ptrs); - } - } - - std::cout << "found " << device_op_ptrs.size() << " instances" << std::endl; + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; // run reference if(do_verification) @@ -207,7 +158,7 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); - std::string best_device_op_name; + std::string best_op_name; float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; @@ -215,14 +166,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, bool pass = true; // profile device operation instances - for(auto& device_op_ptr : device_op_ptrs) + for(auto& op_ptr : op_ptrs) { - auto argument_ptr = device_op_ptr->MakeArgumentPointer( + auto argument_ptr = op_ptr->MakeArgumentPointer( a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), std::array{d0_m_n_device_buf.GetDeviceBuffer(), d1_m_n_device_buf.GetDeviceBuffer()}, - static_cast(e_device_buf.GetDeviceBuffer()), + e_device_buf.GetDeviceBuffer(), M, N, K, @@ -234,11 +185,11 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, b_element_op, cde_element_op); - auto invoker_ptr = device_op_ptr->MakeInvokerPointer(); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string device_op_name = device_op_ptr->GetTypeString(); + std::string op_name = op_ptr->GetTypeString(); - if(device_op_ptr->IsSupportedArgument(argument_ptr.get())) + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { // re-init E to zero before profiling a kernel e_device_buf.SetZero(); @@ -256,14 +207,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << device_op_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << std::endl; if(tflops > best_tflops) { - best_device_op_name = device_op_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; } if(do_verification) @@ -276,14 +227,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, } else { - std::cout << device_op_name << " does not support this problem" << std::endl; + std::cout << op_name << " does not support this problem" << std::endl; } } std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_device_op_name << std::endl; + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - return pass ? 
0 : 1; + return pass; } } // namespace profiler diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 792a04516c..2122010c7f 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -12,112 +12,37 @@ #include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/device_gemm_instance.hpp" + #include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/conv_util.hpp" #include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector&); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector&); - -void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void 
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - namespace ck { namespace profiler { template -void profile_gemm_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideC, - int KBatch) +int profile_gemm_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC) { + bool pass = true; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(is_same::value) @@ -134,32 +59,25 @@ void profile_gemm_impl(int do_verification, Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; - std::size_t num_thread = 1; switch(init_method) { - // case 0: break; - case 0: - a_m_k.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; + case 0: break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -176,303 +94,65 @@ void profile_gemm_impl(int do_verification, b_device_buf.ToDevice(b_k_n.mData.data()); c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - // add device GEMM instances - std::vector gemm_ptrs; + // add device op instances + const auto op_ptrs = 
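Background for the KBatch parameter that profile_gemm_impl drops here (split-K profiling moves to the new profile_gemm_splitk.cpp): a split-K kernel partitions the reduction dimension so several workgroups each compute a partial product, and the partials are then combined. In index form, with the even slice decomposition being illustrative:

// C[m][n] = sum over kb in [0, KBatch) of
//             sum over k in [kb*K/KBatch, (kb+1)*K/KBatch) of A[m][k] * B[k][n]
//
// Each kb slice runs as an independent GEMM over K/KBatch elements; the
// KBatch partial results are then reduced (for example via atomics or a
// second kernel) into C. This is also why C must be zeroed before each run.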
ck::tensor_operation::device::device_gemm_instance:: + get_device_gemm_instances(); - if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); - } - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - 
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if(KBatch > 1) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } - else - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); - } - } - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemm_ptrs); - } - } - else if constexpr(is_same::value && is_same::value && - is_same::value) - { - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); - } - else if 
constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); - } - else if constexpr(is_same::value && - is_same::value && - is_same::value) - { - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); - } - } - - if(gemm_ptrs.size() <= 0) + if(op_ptrs.size() <= 0) { throw std::runtime_error("wrong! no device GEMM instance found"); } - std::string best_gemm_name; + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; // profile device GEMM instances - for(auto& gemm_ptr : gemm_ptrs) + for(auto& op_ptr : op_ptrs) { auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - KBatch); + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { // re-init C to zero before profiling next kernel - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c_device_buf.SetZero(); - std::string gemm_name = gemm_ptr->GetTypeString(); + std::string op_name = op_ptr->GetTypeString(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -487,11 +167,11 @@ void profile_gemm_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << gemm_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << std::endl; if(tflops > best_tflops) { - best_gemm_name = gemm_name; + best_op_name = op_name; best_tflops = tflops; best_ave_time = ave_time; best_gb_per_sec = gb_per_sec; @@ -501,86 +181,15 @@ void profile_gemm_impl(int do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - if constexpr(is_same::value && - is_same::value && - is_same::value) - { - Tensor a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor 
b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_f32_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - bf16_to_f32_(a_m_k, a_f32_m_k); - bf16_to_f32_(b_k_n, b_f32_k_n); - bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result); - - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k, - b_f32_k_n, - c_m_n_host_result, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - - ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); - - if(do_log) - { - LogRangeAsType( - std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - } - } - else - { - Tensor c_m_n_host_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); - - if(do_log) - { - LogRangeAsType( - std::cout << "c_host : ", c_m_n_host_result.mData, ",") - << std::endl; - } - } + pass = + pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); if(do_log) { LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl; } @@ -588,8 +197,7 @@ void profile_gemm_impl(int do_verification, } else { - std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" - << std::endl; + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; } } @@ -631,7 +239,9 @@ void profile_gemm_impl(int do_verification, std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " - << best_gemm_name << std::endl; + << best_op_name << std::endl; + + return pass ? 0 : 1; } } // namespace profiler diff --git a/profiler/include/profile_gemm_splitk_impl.hpp b/profiler/include/profile_gemm_splitk_impl.hpp new file mode 100644 index 0000000000..608c53af45 --- /dev/null +++ b/profiler/include/profile_gemm_splitk_impl.hpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
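+//
+// profile_gemm_splitk_impl() enumerates every registered split-K GEMM device
+// instance for the requested data types and layouts, times each one, checks
+// the device result against a host reference GEMM when verification is on,
+// and returns true only if all supported instances pass.
+//
+// Minimal usage sketch (comment only; the aliases and the template-argument
+// order shown here are assumptions for illustration, not part of this header):
+//
+//   using F16 = ck::half_t;
+//   using Row = ck::tensor_layout::gemm::RowMajor;
+//   using Col = ck::tensor_layout::gemm::ColumnMajor;
+//
+//   bool ok = ck::profiler::profile_gemm_splitk_impl<F16, F16, float, F16,
+//                                                    Row, Col, Row>(
+//       1 /*do_verification*/, 1 /*init_method*/, false /*do_log*/,
+//       true /*time_kernel*/, 3840 /*M*/, 4096 /*N*/, 4096 /*K*/,
+//       4096 /*StrideA*/, 4096 /*StrideB*/, 4096 /*StrideC*/, 4 /*KBatch*/);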
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/device_gemm_splitk_instance.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/host_tensor/device_memory.hpp" +#include "ck/library/host_tensor/host_tensor.hpp" +#include "ck/library/host_tensor/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_splitk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + // add device op instances + const auto op_ptrs = + ck::tensor_operation::device::device_gemm_instance::get_device_gemm_splitk_instances< + ADataType, + BDataType, + CDataType, + ALayout, + BLayout, + CLayout>(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! 
no device operation instance found"); + } + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = + pass & ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_batched_gemm.cpp 
b/profiler/src/profile_batched_gemm.cpp index bf3b4eb5cd..45ec352e72 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -15,10 +15,6 @@ enum struct GemmMatrixLayout MK_NK_MN, // 1 KM_KN_MN, // 2 KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 }; enum struct GemmDataType @@ -31,7 +27,7 @@ enum struct GemmDataType int profile_batched_gemm(int argc, char* argv[]) { - if(!(argc == 15)) + if(argc != 15) { printf("arg1: tensor operation (batched_gemm: Batched GEMM)\n"); printf("arg2: data type (0: fp32; 1: fp16, 2: bf16, 3: int8)\n"); @@ -64,330 +60,117 @@ int profile_batched_gemm(int argc, char* argv[]) const int BatchCount = std::stoi(argv[14]); - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler:: + profile_batched_gemm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + BatchCount); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? 
N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? 
N : StrideC, - BatchCount); + return profile(F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(BF16{}, BF16{}, BF16{}, Col{}, Col{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Row{}, Col{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Col{}, Row{}, Row{}); } else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_batched_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - BatchCount); + return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{}); } else { - throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); - } + std::cout << "this data_type & layout is not implemented" << std::endl; - return 0; + return 1; + } } diff --git a/profiler/src/profile_convnd_fwd.cpp b/profiler/src/profile_convnd_fwd.cpp index f81fcd9b69..8223be160e 100644 --- a/profiler/src/profile_convnd_fwd.cpp +++ b/profiler/src/profile_convnd_fwd.cpp @@ -10,11 +10,10 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + #include "ck/library/utility/conv_util.hpp" #include "ck/library/utility/fill.hpp" -#include "profiler/include/profile_convnd_fwd.hpp" - namespace { enum struct ConvDataType @@ -304,7 +303,7 @@ void profile_convnd_instances(ConvDataType data_type, } // namespace -int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) +int profile_convnd_fwd(int argc, char* argv[]) { using namespace ck::utils::conv; diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index 891c764183..624f3dbf61 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -14,10 +14,6 @@ enum struct GemmMatrixLayout MK_NK_MN, // 1 KM_KN_MN, // 2 KM_NK_MN, // 3 - MK_KN_NM, // 4 - MK_NK_NM, // 5 - KM_KN_NM, // 6 - KM_NK_NM, // 7 }; enum struct GemmDataType @@ -30,7 +26,7 @@ enum struct GemmDataType int profile_gemm(int argc, char* argv[]) { - if(!(argc == 14 || argc == 15)) + if(argc != 14) { printf("arg1: tensor operation (gemm: GEMM)\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); @@ -41,9 +37,8 @@ int profile_gemm(int argc, char* argv[]) printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); - printf("arg14: split k into mulitiple batch\n"); exit(1); } @@ -61,350 +56,125 @@ int profile_gemm(int argc, char* argv[]) const int StrideA = std::stoi(argv[11]); const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); - int KBatch = 1; - if(argc == 15) - KBatch = std::stoi(argv[14]); - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using INT8 = int8_t; + using INT32 = int32_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = + ck::profiler::profile_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC); + + return pass ? 
0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) - { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); - } - else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) - { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); - } - else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? 
K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); } - else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? K : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Row{}, Row{}); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Row{}, Col{}, Row{}); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); } else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN) { - ck::profiler::profile_gemm_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? M : StrideA, - (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC, - KBatch); + return profile(BF16{}, BF16{}, F32{}, BF16{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(INT8{}, INT8{}, INT32{}, INT8{}, Col{}, Col{}, Row{}); } else { - throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); - } + std::cout << "this data_type & layout is not implemented" << std::endl; - return 0; + return 1; + } } diff --git a/profiler/src/profile_gemm_add_add_fastgelu.cpp b/profiler/src/profile_gemm_add_add_fastgelu.cpp index d0a9da2bda..c4c770c293 100644 --- a/profiler/src/profile_gemm_add_add_fastgelu.cpp +++ b/profiler/src/profile_gemm_add_add_fastgelu.cpp @@ -16,10 +16,6 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) MK_NK_MN_MN_MN, // 1 KM_KN_MN_MN_MN, // 2 KM_NK_MN_MN_MN, // 3 - MK_KN_NM_MN_MN, // 4 - MK_NK_NM_MN_MN, // 5 - KM_KN_NM_MN_MN, // 6 - KM_NK_NM_MN_MN, // 7 }; enum struct MatrixDataType @@ -101,17 +97,17 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) const int DefaultStrideD1 = ck::is_same_v ? N : M; const int DefaultStrideE = ck::is_same_v ? N : M; - return ck::profiler::profile_gemm_add_add_fastgelu_impl( + bool pass = ck::profiler::profile_gemm_add_add_fastgelu_impl( do_verification, init_method, do_log, @@ -124,6 +120,8 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; }; if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) @@ -149,6 +147,6 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[]) { std::cout << "this data_type & layout is not implemented" << std::endl; - return 0; + return 1; } } diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp new file mode 100644 index 0000000000..fff023c8e0 --- /dev/null +++ b/profiler/src/profile_gemm_splitk.cpp @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
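+//
+// Command-line driver for split-K GEMM profiling. One plausible invocation
+// (the binary name ckProfiler is an assumption; it depends on how the
+// profiler target is built):
+//
+//   ./ckProfiler gemm_splitk 1 1 1 1 0 1 3840 4096 4096 -1 -1 -1 4
+//
+// i.e. fp16 data (arg2 = 1), A[m, k] * B[n, k] = C[m, n] layout (arg3 = 1),
+// verification on, integer initialization, no tensor dump, timed kernels,
+// M = 3840, N = 4096, K = 4096, default strides (any negative stride falls
+// back to its packed default below), and KBatch = 4.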
+
+#include
+#include
+#include
+#include
+
+#include "profiler/include/profile_gemm_splitk_impl.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+    MK_NK_MN, // 1
+    KM_KN_MN, // 2
+    KM_NK_MN, // 3
+};
+
+enum struct GemmDataType
+{
+    F32_F32_F32,    // 0
+    F16_F16_F16,    // 1
+    BF16_BF16_BF16, // 2
+    INT8_INT8_INT8, // 3
+};
+
+int profile_gemm_splitk(int argc, char* argv[])
+{
+    if(argc != 15)
+    {
+        printf("arg1: tensor operation (gemm_splitk: Split-K GEMM)\n");
+        printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
+        printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
+        printf("                     1: A[m, k] * B[n, k] = C[m, n];\n");
+        printf("                     2: A[k, m] * B[k, n] = C[m, n];\n");
+        printf("                     3: A[k, m] * B[n, k] = C[m, n])\n");
+        printf("arg4: verification (0: no; 1: yes)\n");
+        printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
+        printf("arg6: print tensor value (0: no; 1: yes)\n");
+        printf("arg7: time kernel (0=no, 1=yes)\n");
+        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
+        printf("arg14: split K into multiple batches\n");
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const int M = std::stoi(argv[8]);
+    const int N = std::stoi(argv[9]);
+    const int K = std::stoi(argv[10]);
+
+    const int StrideA = std::stoi(argv[11]);
+    const int StrideB = std::stoi(argv[12]);
+    const int StrideC = std::stoi(argv[13]);
+    const int KBatch  = std::stoi(argv[14]);
+
+    using F32 = float;
+    using F16 = ck::half_t;
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    auto profile = [&](auto a_type,
+                       auto b_type,
+                       auto acc_type,
+                       auto c_type,
+                       auto a_layout,
+                       auto b_layout,
+                       auto c_layout) {
+        using ADataType   = decltype(a_type);
+        using BDataType   = decltype(b_type);
+        using AccDataType = decltype(acc_type);
+        using CDataType   = decltype(c_type);
+
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using CLayout = decltype(c_layout);
+
+        const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
+        const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
+        const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
+
+        bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
+                                                           BDataType,
+                                                           AccDataType,
+                                                           CDataType,
+                                                           ALayout,
+                                                           BLayout,
+                                                           CLayout>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? DefaultStrideA : StrideA,
+            (StrideB < 0) ? DefaultStrideB : StrideB,
+            (StrideC < 0) ? DefaultStrideC : StrideC,
+            KBatch);
+
+        return pass ?
0 : 1; + }; + + if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index d21d243607..e30d921da2 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -1,49 +1,47 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include #include -#include "profiler/include/profile_convnd_fwd.hpp" - int profile_gemm(int, char*[]); +int profile_gemm_splitk(int, char*[]); int profile_gemm_bias_2d(int, char*[]); int profile_gemm_bias_relu(int, char*[]); int profile_gemm_bias_relu_add(int, char*[]); -int profile_gemm_reduce(int, char*[]); int profile_gemm_bias_add_reduce(int, char*[]); +int profile_gemm_add_add_fastgelu(int, char*[]); +int profile_gemm_reduce(int, char*[]); int profile_batched_gemm(int, char*[]); +int profile_batched_gemm_reduce(int, char*[]); int profile_grouped_gemm(int, char*[]); int profile_conv_fwd(int, char*[]); int profile_conv_fwd_bias_relu(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]); +int profile_convnd_fwd(int argc, char* argv[]); int profile_convnd_bwd_data(int, char*[], int); -int profile_reduce(int, char*[]); int profile_conv_bwd_weight(int, char*[]); -int profile_batched_gemm_reduce(int, char*[]); -int profile_gemm_add_add_fastgelu(int, char*[]); +int profile_reduce(int, char*[]); static void print_helper_message() { // clang-format off - printf("arg1: tensor operation (gemm: GEMM\n" - " gemm_bias_2d: GEMM+Bias(2D)\n" - " gemm_bias_relu: GEMM+Bias+ReLU\n" - " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" - " gemm_reduce: GEMM+Reduce\n" - " grouped_gemm: Grouped GEMM\n" - " conv_fwd: ForwardConvolution\n" - " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" - " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" - " conv1d_bwd_data: BackwardConvolution data 1 dim\n" - " conv2d_bwd_data: BackwardConvolution data 2 dim\n" - " conv3d_bwd_data: BackwardConvolution data 3 dim\n" - " reduce: Reduce\n" - " conv2d_bwd_weight: Backward Weight Convolution 2d\n" - " gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n"); + printf("arg1: tensor 
operation (gemm: GEMM\n" + " gemm_splitk: Split-K GEMM\n" + " gemm_bias_2d: GEMM+Bias(2D)\n" + " gemm_bias_relu: GEMM+Bias+ReLU\n" + " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" + " gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n" + " gemm_reduce: GEMM+Reduce\n" + " batched_gemm: Batched GEMM\n" + " grouped_gemm: Grouped GEMM\n" + " conv_fwd: ForwardConvolution\n" + " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" + " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" + " conv1d_bwd_data: BackwardConvolution data 1 dim\n" + " conv2d_bwd_data: BackwardConvolution data 2 dim\n" + " conv3d_bwd_data: BackwardConvolution data 3 dim\n" + " conv2d_bwd_weight: Backward Weight Convolution 2d\n" + " reduce: Reduce\n"); // clang-format on } @@ -60,6 +58,10 @@ int main(int argc, char* argv[]) { return profile_gemm(argc, argv); } + else if(strcmp(argv[1], "gemm_splitk") == 0) + { + return profile_gemm_splitk(argc, argv); + } else if(strcmp(argv[1], "gemm_bias_2d") == 0) { return profile_gemm_bias_2d(argc, argv); @@ -94,7 +96,7 @@ int main(int argc, char* argv[]) } else if(strcmp(argv[1], "conv_fwd") == 0) { - return ck::profiler::profile_convnd_fwd(argc, argv); + return profile_convnd_fwd(argc, argv); } else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) { diff --git a/test/batched_gemm/batched_gemm_util.hpp b/test/batched_gemm/batched_gemm_util.hpp deleted file mode 100644 index ffc46133b8..0000000000 --- a/test/batched_gemm/batched_gemm_util.hpp +++ /dev/null @@ -1,109 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef BATCHED_GEMM_UTILS_HPP -#define BATCHED_GEMM_UTILS_HPP - -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" - -namespace ck { -namespace batched_gemm_util { - -struct GemmParams -{ - GemmParams() - : M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0) - { - } - - ck::index_t M; - ck::index_t N; - ck::index_t K; - - ck::index_t StrideA; - ck::index_t StrideB; - ck::index_t StrideC; - - float alpha; - float beta; -}; - -template -void RunHostBatchedGemm(const Tensor& A, - const Tensor& B, - Tensor& C, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) -{ - auto ref_batched_gemm = BatchedGemmInstance{}; - auto ref_invoker = ref_batched_gemm.MakeInvoker(); - - auto ref_argument = - ref_batched_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); -} - -template -void RunDeviceBatchedGemm(DeviceGemmPtr& batched_gemm_ptr, - const ck::batched_gemm_util::GemmParams& params, - const Tensor& A, - const Tensor& B, - Tensor& C, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) -{ - DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace()); - DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); - DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); - - a_g_m_k_device_buf.ToDevice(A.mData.data()); - b_g_k_n_device_buf.ToDevice(B.mData.data()); - - const auto batch_count = A.mDesc.GetLengths()[0]; - auto invoker_ptr = batched_gemm_ptr->MakeInvokerPointer(); - auto argument_ptr = batched_gemm_ptr->MakeArgumentPointer( - static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_g_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_g_m_n_device_buf.GetDeviceBuffer()), - params.M, - params.N, - params.K, - 
params.StrideA, - params.StrideB, - params.StrideC, - a_element_op, - b_element_op, - c_element_op, - batch_count); - - if(!batched_gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - invoker_ptr->Run(argument_ptr.get()); - c_g_m_n_device_buf.FromDevice(C.mData.data()); -} - -} // namespace batched_gemm_util -} // namespace ck -#endif diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index b3cb710d1c..7af3799e7e 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -214,6 +214,11 @@ struct TestGemm res = ck::utils::check_err(c_device.mData, c_host.mData); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } else if(std::is_same::value) { res = ck::utils::check_err(c_device.mData, c_host.mData); @@ -234,121 +239,5 @@ struct TestGemm } }; -template -struct TestGemmBF16 -{ - using BF16 = ck::bhalf_t; - - auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params) - { - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - // use fp32 host kernel to verify bf16 device kernel - Tensor a_m_k_bf16( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n_bf16( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_device_bf16( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - Tensor a_m_k_fp32( - f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); - Tensor b_k_n_fp32( - f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); - Tensor c_m_n_host_fp32( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - Tensor c_m_n_device_fp32( - f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); - - a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - - bf16_to_f32_(a_m_k_bf16, a_m_k_fp32); - bf16_to_f32_(b_k_n_bf16, b_k_n_fp32); - - return std::make_tuple(a_m_k_bf16, - b_k_n_bf16, - c_m_n_device_bf16, - a_m_k_fp32, - b_k_n_fp32, - c_m_n_host_fp32, - c_m_n_device_fp32); - } - - auto operator()(DeviceGemmPtr_& gemmPtr) - { - // Arrange - ck::gemm_util::GemmParams params; - params.M = 1024; - params.N = 1024; - params.K = 1024; - params.StrideA = 1024; - params.StrideB = 1024; - params.StrideC = 1024; - - auto host_tensors = PrepareGemmTensorBF16(params); - const Tensor& a_bf16 = std::get<0>(host_tensors); - const Tensor& b_bf16 = std::get<1>(host_tensors); - Tensor& c_device_bf16 = std::get<2>(host_tensors); - Tensor& a_fp32 = std::get<3>(host_tensors); - Tensor& b_fp32 = std::get<4>(host_tensors); - Tensor& c_host_fp32 = std::get<5>(host_tensors); - Tensor& c_device_fp32 = std::get<6>(host_tensors); - - auto a_element_op = AElementwiseOperation{}; - auto b_element_op = BElementwiseOperation{}; - auto c_element_op = CElementwiseOperation{}; - - // use fp32 host kernel to verify bf16 device kernel - using ReferenceGemmInstance = - 
ck::tensor_operation::host::ReferenceGemm; - ck::gemm_util::RunHostGEMM( - a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op); - - // Act - ck::gemm_util::RunDeviceGEMM(gemmPtr, - params, - a_bf16, - b_bf16, - c_device_bf16, - a_element_op, - b_element_op, - c_element_op); - - bf16_to_f32_(c_device_bf16, c_device_fp32); - - // Assert - bool res = ck::utils::check_err( - c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - - return res; - }; -}; - } // namespace gemm_util } // namespace ck diff --git a/test/gemm/gemm_xdl_bf16.cpp b/test/gemm/gemm_xdl_bf16.cpp index 2b3bd7c98d..415141c2cc 100644 --- a/test/gemm/gemm_xdl_bf16.cpp +++ b/test/gemm/gemm_xdl_bf16.cpp @@ -47,6 +47,11 @@ void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( int main() { + using ADataType = ck::bhalf_t; + using BDataType = ck::bhalf_t; + using CDataType = ck::bhalf_t; + using AccDataType = float; + using RowMajor = ck::tensor_layout::gemm::RowMajor; using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; @@ -58,13 +63,17 @@ int main() for(auto& gemmPtr : gemmPtrs) { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + res &= ck::gemm_util::TestGemm{}(gemmPtr); } gemmPtrs.clear(); @@ -73,13 +82,17 @@ int main() for(auto& gemmPtr : gemmPtrs) { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + res &= ck::gemm_util::TestGemm{}(gemmPtr); } gemmPtrs.clear(); @@ -88,13 +101,17 @@ int main() for(auto& gemmPtr : gemmPtrs) { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + res &= ck::gemm_util::TestGemm{}(gemmPtr); } gemmPtrs.clear(); @@ -103,13 +120,17 @@ int main() for(auto& gemmPtr : gemmPtrs) { - res &= ck::gemm_util::TestGemmBF16{}(gemmPtr); + res &= ck::gemm_util::TestGemm{}(gemmPtr); } std::cout << "TestGemm ..... " << (res ? 
"SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_xdl_fp16.cpp b/test/gemm/gemm_xdl_fp16.cpp index 9035eb4241..fac4d346df 100644 --- a/test/gemm/gemm_xdl_fp16.cpp +++ b/test/gemm/gemm_xdl_fp16.cpp @@ -38,10 +38,12 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +#if 0 void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector&); +#endif void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); @@ -69,8 +71,10 @@ int main() std::vector gemmPtrs; ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs); @@ -92,8 +96,10 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs); @@ -115,8 +121,10 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); @@ -138,8 +146,10 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); ck::tensor_operation::device::device_gemm_instance:: diff --git a/test/gemm/gemm_xdl_fp32.cpp b/test/gemm/gemm_xdl_fp32.cpp index a3787bcdde..0a83782629 100644 --- a/test/gemm/gemm_xdl_fp32.cpp +++ b/test/gemm/gemm_xdl_fp32.cpp @@ -38,10 +38,12 @@ void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +#if 0 void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); +#endif void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector&); @@ -67,8 +69,10 @@ int main() std::vector gemmPtrs; ck::tensor_operation::device::device_gemm_instance:: 
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs); @@ -90,8 +94,10 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs); @@ -113,8 +119,10 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); @@ -136,8 +144,10 @@ int main() gemmPtrs.clear(); ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); +#if 0 ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); +#endif ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index 40d422377b..ab1d016c9d 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,3 +1,3 @@ add_test_executable(test_gemm_split_k gemm_split_k.cpp) target_link_libraries(test_gemm_split_k PRIVATE host_tensor) -target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance) +target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance) diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp index d21d35ec25..ed732b09c3 100644 --- a/test/gemm_split_k/gemm_split_k.cpp +++ b/test/gemm_split_k/gemm_split_k.cpp @@ -15,7 +15,6 @@ #include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/host_tensor/host_tensor_generator.hpp" -#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/host_tensor/host_gemm.hpp" @@ -28,20 +27,24 @@ enum struct GemmMatrixLayout KM_NK_MN, // 3 }; -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; +using DeviceGemmSplitKNoOpPtr = ck::tensor_operation::device::DeviceGemmSplitKPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector&); +void 
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
+    std::vector<DeviceGemmSplitKNoOpPtr>&);
+void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
+    std::vector<DeviceGemmSplitKNoOpPtr>&);
+void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
+    std::vector<DeviceGemmSplitKNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device

@@ -150,7 +153,7 @@ int test_gemm(const gemmArgs& args)
     c_device_buf.ToDevice(c_m_n_device_result.mData.data());

     // add device GEMM instances
-    std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
+    std::vector<DeviceGemmSplitKNoOpPtr> gemm_ptrs;

     if(args.layout == GemmMatrixLayout::MK_KN_MN)
     {