diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index 01f87509..b300ca47 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -6,8 +6,6 @@ #if defined(__HIP_PLATFORM_AMD__) -#include -#include #include using cudaError_t = hipError_t; @@ -92,14 +90,7 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri #else #include -#include #include -#if (CUDART_VERSION >= 11000) -#include -#endif -#if (CUDART_VERSION >= 11080) -#include -#endif #endif diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp new file mode 100644 index 00000000..224e56de --- /dev/null +++ b/include/mscclpp/gpu_data_types.hpp @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef MSCCLPP_GPU_DATA_TYPES_HPP_ +#define MSCCLPP_GPU_DATA_TYPES_HPP_ + +#if defined(__HIP_PLATFORM_AMD__) + +#include +#include + +#else + +#include +#if (CUDART_VERSION >= 11000) +#include +#endif +#if (CUDART_VERSION >= 11080) +#include +#endif + +#endif + +#endif // MSCCLPP_GPU_DATA_TYPES_HPP_ diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp index 1b3d6bc5..b04defbc 100644 --- a/include/mscclpp/nvls_device.hpp +++ b/include/mscclpp/nvls_device.hpp @@ -5,6 +5,7 @@ #define MSCCLPP_NVLS_DEVICE_HPP_ #include +#include #include #include "device.hpp" diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 08e8796a..f25f35a6 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -5,6 +5,7 @@ #define MSCCLPP_EXECUTION_KERNEL_HPP_ #include +#include #include #include #include