diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 2c9d2d78cd..67f6160873 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,4 +1,17 @@ +add_custom_target(example_grouped_gemm_xdl) + add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) + +add_dependencies(example_grouped_gemm_xdl + example_grouped_gemm_xdl_fp32 + example_grouped_gemm_xdl_fp16 + example_grouped_gemm_xdl_bfp16 + example_grouped_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) +endif() diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp new file mode 100644 index 0000000000..7355641d98 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using DsDataType = ck::Tuple<>; +using EDataType = ck::int4_t; + +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelEDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off + < ALayout, //ALayout + BLayout, //BLayout + DsLayout, //DsLayout + ELayout, //ELayout + KernelADataType, //ADataType + KernelBDataType, //BDataType + AccDataType, //AccDataType + CShuffleDataType, //CShuffleDataType + DsDataType, //DsDataType + KernelEDataType, //EDataType + AElementOp, //AElementwiseOperation + BElementOp, //BElementwiseOperation + CDEElementOp, //CDEElementwiseOperation + GemmDefault, //GEMMSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl + 16>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +#define BUILD_INT4_EXAMPLE +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index e1a4134846..01ba4ec045 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -22,6 +22,12 @@ struct ExecutionConfig final bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) { +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(KernelADataType)); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType)); + static_assert(sizeof(EDataType) == sizeof(KernelEDataType)); +#endif int group_count = problem_size.group_count; // GEMM shape @@ -61,7 +67,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co std::vector> a_tensors; std::vector> b_tensors; std::vector> c_host_tensors; +#ifdef BUILD_INT4_EXAMPLE + std::vector> c_device_tensors; +#else std::vector> c_device_tensors; +#endif a_tensors.reserve(group_count); b_tensors.reserve(group_count); @@ -86,9 +96,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{}))); c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); +#ifdef BUILD_INT4_EXAMPLE + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); +#else c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{}))); - +#endif std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; @@ -124,8 +138,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co c_tensors_device.emplace_back(std::make_unique( sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize())); +#ifdef BUILD_INT4_EXAMPLE + const Tensor a_converted(a_tensors[i]); + const Tensor b_converted(b_tensors[i]); + + a_tensors_device[i]->ToDevice(a_converted.mData.data()); + b_tensors_device[i]->ToDevice(b_converted.mData.data()); +#else a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); +#endif p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); p_b.push_back(b_tensors_device[i]->GetDeviceBuffer()); @@ -156,14 +178,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; + invoker.Run(argument, StreamConfig{nullptr, false}); bool pass = true; if(config.do_verification) @@ -190,11 +205,28 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co c_element_op); ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor c_device_result_converted(c_device_tensors[i]); + pass &= ck::utils::check_err(c_device_result_converted.mData, c_host_tensors[i].mData); + +#else pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); +#endif } } - return pass ? 0 : 1; + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; } bool run_grouped_gemm_example(int argc, char* argv[]) @@ -208,7 +240,7 @@ bool run_grouped_gemm_example(int argc, char* argv[]) { problem_size.Ms.push_back(256 + 256 * i); problem_size.Ns.push_back(128 + 128 * i); - problem_size.Ks.push_back(64 + 64 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 0bad707f24..1564561156 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -5,7 +5,13 @@ add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) -add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) -add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) -add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) -add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) +add_dependencies(example_cgemm_xdl + example_cgemm_xdl_bf16 + example_cgemm_xdl_fp16 + example_cgemm_xdl_fp32 + example_cgemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) +endif() diff --git a/example/22_cgemm/cgemm_xdl_bf16.cpp b/example/22_cgemm/cgemm_xdl_bf16.cpp index 5f73c684c7..4369be8a32 100644 --- a/example/22_cgemm/cgemm_xdl_bf16.cpp +++ b/example/22_cgemm/cgemm_xdl_bf16.cpp @@ -117,16 +117,16 @@ int main(int argc, char* argv[]) exit(0); } - return run_cgemm_xdl( + return !run_cgemm_xdl( M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); } diff --git a/example/22_cgemm/cgemm_xdl_common.hpp b/example/22_cgemm/cgemm_xdl_common.hpp index d388a6e71b..f420ac24d5 100644 --- a/example/22_cgemm/cgemm_xdl_common.hpp +++ b/example/22_cgemm/cgemm_xdl_common.hpp @@ -21,6 +21,9 @@ using F32 = float; using BF16 = ck::bhalf_t; using INT8 = std::int8_t; using INT32 = std::int32_t; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using INT4 = ck::int4_t; +#endif template -int run_cgemm_xdl(ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, - bool do_verification, - int init_method, - bool time_kernel) + typename ReferenceCGemmInstance, + typename KernelADataType = ADataType, + typename KernelBDataType = BDataType, + typename KernelCDataType = CDataType> +bool run_cgemm_xdl(ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + bool do_verification, + int init_method, + bool time_kernel) { +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + static_assert(sizeof(ck::int4_t) == sizeof(int8_t), + "sizeof ck::int4_t and int8_t is different!"); + static_assert(sizeof(ADataType) == sizeof(KernelADataType), + "sizeof ADataType and KernelADataType is different!"); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType), + "sizeof BDataType and KernelBDataType is different!"); + static_assert(sizeof(CDataType) == sizeof(KernelCDataType), + "sizeof CDataType and KernelCDataType is different!"); +#endif + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if(std::is_same::value) @@ -61,8 +78,10 @@ int run_cgemm_xdl(ck::index_t M, Tensor a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_real_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_imag_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; @@ -89,20 +108,41 @@ int run_cgemm_xdl(ck::index_t M, auto cgemm = DeviceCGemmInstance{}; - DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpaceSize()); - DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpaceSize()); - DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * + DeviceMem a_m_k_real_device_buf(sizeof(KernelADataType) * + a_m_k_real.mDesc.GetElementSpaceSize()); + DeviceMem a_m_k_imag_device_buf(sizeof(KernelADataType) * + a_m_k_imag.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_real_device_buf(sizeof(KernelBDataType) * + b_k_n_real.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_imag_device_buf(sizeof(KernelBDataType) * + b_k_n_imag.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_real_device_buf(sizeof(KernelCDataType) * c_m_n_real_device_result.mDesc.GetElementSpaceSize()); - DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * + DeviceMem c_m_n_imag_device_buf(sizeof(KernelCDataType) * c_m_n_imag_device_result.mDesc.GetElementSpaceSize()); DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC)); - a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); - a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); - b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data()); - b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data()); +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + Tensor a_m_k_real_converted(a_m_k_real); + Tensor a_m_k_imag_converted(a_m_k_imag); + Tensor b_k_n_real_converted(b_k_n_real); + Tensor b_k_n_imag_converted(b_k_n_imag); + + a_m_k_real_device_buf.ToDevice(a_m_k_real_converted.mData.data()); + a_m_k_imag_device_buf.ToDevice(a_m_k_imag_converted.mData.data()); + b_k_n_real_device_buf.ToDevice(b_k_n_real_converted.mData.data()); + b_k_n_imag_device_buf.ToDevice(b_k_n_imag_converted.mData.data()); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); + a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); + b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data()); + b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data()); + } auto a_element_op = AElementwiseOperation{}; auto b_element_op = BElementwiseOperation{}; @@ -111,13 +151,13 @@ int run_cgemm_xdl(ck::index_t M, // do GEMM auto invoker = cgemm.MakeInvoker(); auto argument = - cgemm.MakeArgument(static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), - static_cast(a_m_k_imag_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_real_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_imag_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), - static_cast(workspace_device_buf.GetDeviceBuffer()), + cgemm.MakeArgument(static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_imag_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_real_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_imag_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), + static_cast(workspace_device_buf.GetDeviceBuffer()), M, N, K, @@ -142,16 +182,12 @@ int run_cgemm_xdl(ck::index_t M, std::size_t(2) * (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N); - float tflops = static_cast(flop) / 1.E9 / ave_time; - + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << cgemm.GetTypeString() << std::endl; - c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data()); - c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data()); - if(do_verification) { Tensor c_m_n_real_host_result( @@ -159,9 +195,8 @@ int run_cgemm_xdl(ck::index_t M, Tensor c_m_n_imag_host_result( f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - auto ref_cgemm = ReferenceCGemmInstance{}; - auto ref_invoker = ref_cgemm.MakeInvoker(); - + auto ref_cgemm = ReferenceCGemmInstance{}; + auto ref_invoker = ref_cgemm.MakeInvoker(); auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real, a_m_k_imag, b_k_n_real, @@ -174,19 +209,45 @@ int run_cgemm_xdl(ck::index_t M, ref_invoker.Run(ref_argument); + c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data()); + c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data()); + bool result = true; - result = ck::utils::check_err(c_m_n_real_device_result.mData, - c_m_n_real_host_result.mData, - "Verification error: incorrect results in real part!", - 1e-2f, - 1e-1f); - result = result && - ck::utils::check_err(c_m_n_imag_device_result.mData, - c_m_n_imag_host_result.mData, - "Verification error: incorrect results in imaginary part!", - 1e-2f, - 1e-1f); - return result ? 0 : 1; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + if constexpr(std::is_same_v) + { + const Tensor c_m_n_real_device_result_converted(c_m_n_real_device_result); + const Tensor c_m_n_imag_device_result_converted(c_m_n_imag_device_result); + + result = ck::utils::check_err(c_m_n_real_device_result_converted.mData, + c_m_n_real_host_result.mData, + "Verification error: incorrect results in real part!", + 1e-2f, + 1e-1f); + result = result && ck::utils::check_err( + c_m_n_imag_device_result_converted.mData, + c_m_n_imag_host_result.mData, + "Verification error: incorrect results in imaginary part!", + 1e-2f, + 1e-1f); + } + else +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + { + result = ck::utils::check_err(c_m_n_real_device_result.mData, + c_m_n_real_host_result.mData, + "Verification error: incorrect results in real part!", + 1e-2f, + 1e-1f); + result = result && ck::utils::check_err( + c_m_n_imag_device_result.mData, + c_m_n_imag_host_result.mData, + "Verification error: incorrect results in imaginary part!", + 1e-2f, + 1e-1f); + } + + return result; } - return 0; + return true; } diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 7909bc1d65..a73d41e82f 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -116,16 +116,16 @@ int main(int argc, char* argv[]) exit(0); } - return run_cgemm_xdl( + return !run_cgemm_xdl( M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); } diff --git a/example/22_cgemm/cgemm_xdl_fp32.cpp b/example/22_cgemm/cgemm_xdl_fp32.cpp index 53b6afbc89..ac32ba768d 100644 --- a/example/22_cgemm/cgemm_xdl_fp32.cpp +++ b/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -117,16 +117,16 @@ int main(int argc, char* argv[]) exit(0); } - return run_cgemm_xdl( + return !run_cgemm_xdl( M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); } diff --git a/example/22_cgemm/cgemm_xdl_int4.cpp b/example/22_cgemm/cgemm_xdl_int4.cpp new file mode 100644 index 0000000000..cf3cbbc2ac --- /dev/null +++ b/example/22_cgemm/cgemm_xdl_int4.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "cgemm_xdl_common.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +using ADataType = INT4; +using BDataType = INT4; +using CDataType = INT4; +using AccDataType = INT32; +using CShuffleDataType = INT32; + +using KernelADataType = INT8; +using KernelBDataType = INT8; +using KernelCDataType = INT8; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 16, // index_t ABlockTransferSrcScalarPerVector + 16, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // CGEMM shape + ck::index_t M = 1024; + ck::index_t N = 1152; + ck::index_t K = 512; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideC = N; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + return !run_cgemm_xdl( + M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); +} diff --git a/example/22_cgemm/cgemm_xdl_int8.cpp b/example/22_cgemm/cgemm_xdl_int8.cpp index be91877387..e1389ac923 100644 --- a/example/22_cgemm/cgemm_xdl_int8.cpp +++ b/example/22_cgemm/cgemm_xdl_int8.cpp @@ -117,16 +117,16 @@ int main(int argc, char* argv[]) exit(0); } - return run_cgemm_xdl( + return !run_cgemm_xdl( M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel); } diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 8ca5e55dcb..7962576e87 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,4 +1,17 @@ +add_custom_target(example_batched_gemm_xdl) + add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) + +add_dependencies(example_batched_gemm_xdl + example_batched_gemm_xdl_fp32 + example_batched_gemm_xdl_fp16 + example_batched_gemm_xdl_bfp16 + example_batched_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) +endif() diff --git a/example/24_batched_gemm/batched_gemm_xdl_int4.cpp b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp new file mode 100644 index 0000000000..95e715efa8 --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp @@ -0,0 +1,99 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using DsDataType = ck::Tuple<>; +using EDataType = ck::int4_t; + +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelEDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl + // clang-format off + < ALayout, //ALayout + BLayout, //BLayout + DsLayout, //DsLayout + ELayout, //ELayout + KernelADataType, //ADataType + KernelBDataType, //BDataType + AccDataType, //AccDataType + CShuffleDataType, //CShuffleDataType + DsDataType, //DsDataType + KernelEDataType, //EDataType + AElementOp, //AElementwiseOperation + BElementOp, //BElementwiseOperation + CDEElementOp, //CDEElementwiseOperation + GemmDefault, //GEMMSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // ABlockTransfer SrcAccessOrder + 2, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<1, 0, 2>, // BBlockTransfer SrcAccessOrder + 2, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 64, 1, 4>, // CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl + 16>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +#define BUILD_INT4_EXAMPLE +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 2db6ab76be..20bef9f935 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -1,3 +1,5 @@ +#include + #pragma once struct ProblemSize final @@ -28,7 +30,23 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co { using namespace ck::literals; - auto& [M, N, K, stride_A, stride_B, stride_C, batch_stride_A, batch_stride_B, batch_stride_C, batch_count] = problem_size; +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(KernelADataType)); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType)); + static_assert(sizeof(EDataType) == sizeof(KernelEDataType)); +#endif + + auto& [M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count] = problem_size; // GEMM shape auto f_host_tensor_descriptor = [](std::size_t batch_count_, @@ -53,9 +71,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); Tensor b_g_k_n( f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); - +#ifdef BUILD_INT4_EXAMPLE + Tensor e_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); +#else Tensor e_g_m_n_device_result( f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); +#endif std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; @@ -78,9 +100,16 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize()); +#ifdef BUILD_INT4_EXAMPLE + const Tensor a_g_m_k_converted(a_g_m_k); + const Tensor b_g_k_n_converted(b_g_k_n); + + a_device_buf.ToDevice(a_g_m_k_converted.mData.data()); + b_device_buf.ToDevice(b_g_k_n_converted.mData.data()); +#else a_device_buf.ToDevice(a_g_m_k.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data()); - +#endif auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; @@ -116,28 +145,21 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = std::size_t(2) * batch_count * M * N * K; - std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + - sizeof(BDataType) * batch_count * K * N + - sizeof(EDataType) * batch_count * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - + invoker.Run(argument, StreamConfig{nullptr, false}); bool pass = true; if(config.do_verification) { c_device_buf.FromDevice(e_g_m_n_device_result.mData.data()); - using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: - ReferenceBatchedGemm; + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; auto ref_invoker = ref_batched_gemm.MakeInvoker(); @@ -150,8 +172,29 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ref_invoker.Run(ref_argument); +#ifdef BUILD_INT4_EXAMPLE + const Tensor e_device_result_converted(e_g_m_n_device_result); + pass &= ck::utils::check_err(e_device_result_converted.mData, e_g_m_n_host_result.mData); + +#else pass = ck::utils::check_err( - e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c"); + e_g_m_n_device_result.mData, e_g_m_n_host_result.mData, "Error: Incorrect results c"); +#endif + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * K * N + + sizeof(EDataType) * batch_count * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; } return pass ? 0 : 1; @@ -162,9 +205,12 @@ bool run_batched_gemm_example(int argc, char* argv[]) ProblemSize problem_size; ExecutionConfig config; - problem_size.M = 256 * (rand() % 16 + 1); - problem_size.N = 128 * (rand() % 16 + 1); - problem_size.K = 64 * (rand() % 16 + 1); + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 15); + + problem_size.M = 256 * (dis(gen) + 1); + problem_size.N = 128 * (dis(gen) + 1); + problem_size.K = 64 * (dis(gen) + 2); problem_size.stride_A = problem_size.K; problem_size.stride_B = problem_size.K; diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index ceb20921f3..7945839546 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,4 +1,17 @@ +add_custom_target(example_splitK_gemm_xdl) + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + +add_dependencies(example_splitK_gemm_xdl + example_splitK_gemm_xdl_fp32 + example_splitK_gemm_xdl_fp16 + example_splitK_gemm_xdl_bfp16 + example_splitK_gemm_xdl_int8) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) +endif() diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc index cbd43869dd..c78cb36a9a 100644 --- a/example/35_splitK_gemm/run_splitK_gemm_example.inc +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -24,6 +24,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con { using namespace ck::literals; +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); + static_assert(sizeof(ADataType) == sizeof(KernelADataType)); + static_assert(sizeof(BDataType) == sizeof(KernelBDataType)); +#endif + auto& [M, N, K, StrideA, StrideB, StrideC, KBatch] = problem_size; auto f_host_tensor_descriptor = @@ -42,12 +48,11 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; switch(config.init_method) { @@ -69,8 +74,16 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); +#ifdef BUILD_INT4_EXAMPLE + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif c_m_n_device_buf.SetZero(); auto a_element_op = AElementOp{}; @@ -80,19 +93,25 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - KBatch); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), +#endif + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); if(!gemm.IsSupportedArgument(argument)) { @@ -101,23 +120,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false}); + bool pass = true; if(config.do_verification) { + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + auto ref_argument = ref_gemm.MakeArgument( a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); @@ -136,19 +146,33 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con if(std::is_same::value) { - return ck::utils::check_err(c_m_n_device_result.mData, - c_m_n_host_result.mData, - "fp16 incorrect result", - 3e-3, - 1e-3); + pass &= ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "fp16 incorrect result", + 3e-3, + 1e-3); } else { - return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + pass &= ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } - return true; + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; } bool run_splitK_gemm_example(int argc, char* argv[]) diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp new file mode 100644 index 0000000000..d2392faf51 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using AccDataType = int32_t; +using CDataType = int32_t; + +using KernelADataType = int8_t; +using KernelBDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off + , // ABlockTransfer ThreadCluster Lengths_K0_M_K1 + S<0, 2, 1, 3>, // ABlockTransfer ThreadCluster ArrangeOrder + S<0, 2, 1, 3>, // ABlockTransfer SrcAccessOrder + 3, // ABlockTransfer SrcVectorDim + 16, // ABlockTransfer SrcScalarPerVector + 16, // ABlockTransfer DstScalarPerVector_K1 + true, // ABlockLdsExtraM + S<1, 4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1 + S<0, 1, 3, 2>, // BBlockTransfer ThreadCluster ArrangeOrder + S<0, 1, 3, 2>, // BBlockTransfer SrcAccessOrder + 3, // BBlockTransfer SrcVectorDim + 16, // BBlockTransfer SrcScalarPerVector + 16, // BBlockTransfer DstScalarPerVector_K1 + true, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CBlockTransferClusterLengths _MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +#define BUILD_INT4_EXAMPLE +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 71df7d10e5..f9d11fcd8c 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' -git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'