From 9934c982a8222d77f79030496049d8e481c2efbc Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 18 Apr 2024 22:52:43 -0700 Subject: [PATCH] Seperate headers for GPU data types (#291) Prevent unnecessarily including data type headers in everywhere. --- include/mscclpp/gpu.hpp | 9 --------- include/mscclpp/gpu_data_types.hpp | 24 ++++++++++++++++++++++++ include/mscclpp/nvls_device.hpp | 1 + src/include/execution_kernel.hpp | 1 + 4 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 include/mscclpp/gpu_data_types.hpp diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index 01f87509..b300ca47 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -6,8 +6,6 @@ #if defined(__HIP_PLATFORM_AMD__) -#include -#include #include using cudaError_t = hipError_t; @@ -92,14 +90,7 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri #else #include -#include #include -#if (CUDART_VERSION >= 11000) -#include -#endif -#if (CUDART_VERSION >= 11080) -#include -#endif #endif diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp new file mode 100644 index 00000000..224e56de --- /dev/null +++ b/include/mscclpp/gpu_data_types.hpp @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSCCLPP_GPU_DATA_TYPES_HPP_ +#define MSCCLPP_GPU_DATA_TYPES_HPP_ + +#if defined(__HIP_PLATFORM_AMD__) + +#include +#include + +#else + +#include +#if (CUDART_VERSION >= 11000) +#include +#endif +#if (CUDART_VERSION >= 11080) +#include +#endif + +#endif + +#endif // MSCCLPP_GPU_DATA_TYPES_HPP_ diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp index 1b3d6bc5..b04defbc 100644 --- a/include/mscclpp/nvls_device.hpp +++ b/include/mscclpp/nvls_device.hpp @@ -5,6 +5,7 @@ #define MSCCLPP_NVLS_DEVICE_HPP_ #include +#include #include #include "device.hpp" diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 08e8796a..f25f35a6 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -5,6 +5,7 @@ #define MSCCLPP_EXECUTION_KERNEL_HPP_ #include +#include #include #include #include