From ba9c637c0beebcc34b8dc7eeee0626512d7f6b41 Mon Sep 17 00:00:00 2001 From: apoorva Date: Wed, 2 Jul 2025 14:12:52 +0000 Subject: [PATCH] Added examples for gemm_add_relu --- example/69_gemm_add_relu/CMakeLists.txt | 21 +++ example/69_gemm_add_relu/common.hpp | 114 ++++++++++++++ .../gemm_add_relu_wmma_bf16.cpp | 72 +++++++++ .../gemm_add_relu_wmma_fp16.cpp | 72 +++++++++ .../gemm_add_relu_wmma_v3_bf16.cpp | 78 ++++++++++ .../gemm_add_relu_wmma_v3_fp16.cpp | 76 +++++++++ .../gemm_add_relu_xdl_bf16.cpp | 82 ++++++++++ .../gemm_add_relu_xdl_fp16.cpp | 82 ++++++++++ .../run_gem_add_relu_example.inc | 144 +++++++++++++++++ .../run_gemm_add_relu_example_v3.inc | 145 ++++++++++++++++++ 10 files changed, 886 insertions(+) create mode 100644 example/69_gemm_add_relu/CMakeLists.txt create mode 100644 example/69_gemm_add_relu/common.hpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp create mode 100644 example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp create mode 100644 example/69_gemm_add_relu/run_gem_add_relu_example.inc create mode 100644 example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc diff --git a/example/69_gemm_add_relu/CMakeLists.txt b/example/69_gemm_add_relu/CMakeLists.txt new file mode 100644 index 0000000000..936e9acea3 --- /dev/null +++ b/example/69_gemm_add_relu/CMakeLists.txt @@ -0,0 +1,21 @@ +add_custom_target(example_gemm_add_relu_xdl) + +add_library(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) +add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) + +add_library(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) +add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp) + + +add_custom_target(example_gemm_add_relu_wmma) +add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.cpp) + +add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp) + +add_example_executable(example_gemm_add_relu_wmma_v3_fp16 gemm_add_relu_wmma_v3_fp16.cpp) +add_example_executable(example_gemm_add_relu_wmma_v3_bf16 gemm_add_relu_wmma_v3_bf16.cpp) + + + + + diff --git a/example/69_gemm_add_relu/common.hpp b/example/69_gemm_add_relu/common.hpp new file mode 100644 index 0000000000..151653e515 --- /dev/null +++ b/example/69_gemm_add_relu/common.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Row_Tuple = ck::Tuple; +using F16_Tuple = ck::Tuple; +using BF16_Tuple = ck::Tuple; + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; +}; +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideD = std::stoi(argv[9]); + problem_size.StrideE = std::stoi(argv[10]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD," + "StrideE" + << std::endl; + return false; + } + + return true; +} diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp new file mode 100644 index 0000000000..34b791573a --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; + +// clang-format on + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp new file mode 100644 index 0000000000..8459e67cb4 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; + +// clang-format on + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp new file mode 100644 index 0000000000..84a40ab3f5 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_bf16.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using DsDataType = BF16_Tuple; +using EDataType = BF16; + +using Row_Tuple = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + BF16, + BF16, + BF16_Tuple, + BF16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_relu_example_v3.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp new file mode 100644 index 0000000000..cca26d6212 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_wmma_v3_fp16.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = F16_Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using DsLayout = Row_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + Row, + Row, + Row_Tuple, + Row, + F16, + F16, + F16_Tuple, + F16, + F32, + F32, + PassThrough, + PassThrough, + Add, + GemmSpec, + 128, + 128, + 64, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 4>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1>; + +// clang-format on + +#include "run_gemm_add_relu_example_v3.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp new file mode 100644 index 0000000000..824b1c2f10 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = BF16; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); } diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp new file mode 100644 index 0000000000..ef8c4cdcf8 --- /dev/null +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +#include "run_gem_add_relu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); } diff --git a/example/69_gemm_add_relu/run_gem_add_relu_example.inc b/example/69_gemm_add_relu/run_gem_add_relu_example.inc new file mode 100644 index 0000000000..9d17f7863a --- /dev/null +++ b/example/69_gemm_add_relu/run_gem_add_relu_example.inc @@ -0,0 +1,144 @@ +#pragma once + +bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} + +bool run_gemm_add_relu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || + run_gemm_add_relu(problem_size, config); +} diff --git a/example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc b/example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc new file mode 100644 index 0000000000..3c787421eb --- /dev/null +++ b/example/69_gemm_add_relu/run_gemm_add_relu_example_v3.inc @@ -0,0 +1,145 @@ +#pragma once + +bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << device_op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(config.do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} + +bool run_gemm_add_relu_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || + run_gemm_add_relu(problem_size, config); +}