From f7288bc2b1126f18f684fa9ede51839d2fd32ccf Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Fri, 19 Aug 2022 14:47:09 -0400 Subject: [PATCH] Reuse same implementation code for most of GEMM examples --- example/01_gemm/gemm_dl_int4.cpp | 5 +- example/01_gemm/gemm_xdl_int4.cpp | 5 +- example/01_gemm/run_gemm_example.inc | 35 ++++--- example/01_gemm/run_gemm_int4_example.inc | 121 ---------------------- 4 files changed, 29 insertions(+), 137 deletions(-) delete mode 100644 example/01_gemm/run_gemm_int4_example.inc diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index 98a5db52d3..ea45f21665 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -39,6 +39,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; -#include "run_gemm_int4_example.inc" +#define BUILD_INT4_EXAMPLE +#include "run_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_gemm_int4_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index f88f1080e3..b2c40900c2 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -40,6 +40,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; -#include "run_gemm_int4_example.inc" +#define BUILD_INT4_EXAMPLE +#include "run_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_gemm_int4_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index c7509f722d..e0c0e69daf 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -5,6 +5,10 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) { +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + using namespace ck::literals; auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; @@ -59,18 +63,25 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); if(!gemm.IsSupportedArgument(argument)) { diff --git a/example/01_gemm/run_gemm_int4_example.inc b/example/01_gemm/run_gemm_int4_example.inc deleted file mode 100644 index 9b5a9e961c..0000000000 --- a/example/01_gemm/run_gemm_int4_example.inc +++ /dev/null @@ -1,121 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -bool run_gemm_int4(const ProblemSize& problem_size, const ExecutionConfig& config) -{ - using namespace ck::literals; - - auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - - switch(config.init_method) - { - case 0: break; - case 1: - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k.begin(), - a_m_k.end()); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n.begin(), - b_k_n.end()); - break; - default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k.begin(), a_m_k.end()); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n.begin(), b_k_n.end()); - } - - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - - DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * - c_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = - gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - - return false; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = 2_uz * M * N * K; - std::size_t num_btype = sizeof(KernelADataType) * M * K + sizeof(KernelBDataType) * K * N + - sizeof(KernelCDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(config.do_verification) - { - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - - return !ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); - } - - return true; -} - -bool run_gemm_int4_example(int argc, char* argv[]) -{ - ProblemSize problem_size; - ExecutionConfig config; - - return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_int4(problem_size, config); -}