diff --git a/include/mscclpp/assert_device.hpp b/include/mscclpp/assert_device.hpp index 9d0171cb..bf982ba6 100644 --- a/include/mscclpp/assert_device.hpp +++ b/include/mscclpp/assert_device.hpp @@ -12,6 +12,8 @@ #if !defined(DEBUG_BUILD) +/// Assert a condition on the device and print a message if the condition is false. +/// This macro does nothing in a release mode build (when DEBUG_BUILD is undefined). #define MSCCLPP_ASSERT_DEVICE(__cond, __msg) #else // defined(DEBUG_BUILD) @@ -24,6 +26,8 @@ extern "C" __host__ __device__ void __assert_fail(const char *__assertion, const const char *__function) __THROW; #endif // !defined(MSCCLPP_DEVICE_HIP) +/// Assert a condition on the device and print a message if the condition is false. +/// This macro does nothing in a release mode build (when DEBUG_BUILD is undefined). #define MSCCLPP_ASSERT_DEVICE(__cond, __msg) \ do { \ if (!(__cond)) { \ diff --git a/include/mscclpp/concurrency_device.hpp b/include/mscclpp/concurrency_device.hpp index a75e69e0..240cc032 100644 --- a/include/mscclpp/concurrency_device.hpp +++ b/include/mscclpp/concurrency_device.hpp @@ -7,11 +7,24 @@ #include "atomic_device.hpp" #include "poll_device.hpp" -#define NUM_DEVICE_SYNCER_COUNTER 3 - namespace mscclpp { /// A device-wide barrier. +/// This barrier can be used to synchronize multiple thread blocks within a kernel. +/// It uses atomic operations to ensure that all threads in the same kernel reach the barrier before proceeding +/// and they can safely read data written by other threads in different blocks. +/// +/// Example usage of DeviceSyncer: +/// ```cpp +/// __global__ void myKernel(mscclpp::DeviceSyncer* syncer, int numBlocks) { +/// // Do some work here +/// // ... +/// // Synchronize all blocks +/// syncer->sync(numBlocks); +/// // All blocks have reached this point +/// // ... +/// } +/// ``` struct DeviceSyncer { public: /// Construct a new DeviceSyncer object. 
@@ -20,6 +33,9 @@ struct DeviceSyncer { /// Destroy the DeviceSyncer object. MSCCLPP_INLINE ~DeviceSyncer() = default; + /// The number of sync counters. + static const unsigned int NumCounters = 3U; + #if defined(MSCCLPP_DEVICE_COMPILE) /// Synchronize all threads inside a kernel. Guarantee that all previous work of all threads in cooperating blocks is /// finished. @@ -30,14 +46,14 @@ struct DeviceSyncer { __syncthreads(); if (blockNum == 1) return; if (threadIdx.x == 0) { - unsigned int tmp = (preFlag_ + 1) % NUM_DEVICE_SYNCER_COUNTER; - unsigned int next = (tmp + 1) % NUM_DEVICE_SYNCER_COUNTER; - unsigned int* count = &count_[tmp]; - count_[next] = 0; + unsigned int countIdx = (currentCountIdx_ + 1) % NumCounters; + unsigned int nextCountIdx = (countIdx + 1) % NumCounters; + unsigned int* count = &count_[countIdx]; + count_[nextCountIdx] = 0; atomicFetchAdd(count, 1U, memoryOrderRelease); POLL_MAYBE_JAILBREAK((atomicLoad(count, memoryOrderAcquire) != targetCnt), maxSpinCount); - preFlag_ = tmp; + currentCountIdx_ = countIdx; } // We need sync here because only a single thread is checking whether // the flag is flipped. @@ -47,11 +63,37 @@ struct DeviceSyncer { private: /// The counter of synchronized blocks. - unsigned int count_[NUM_DEVICE_SYNCER_COUNTER]; - /// The flag to indicate whether to increase or decrease @ref flag_. - unsigned int preFlag_; + unsigned int count_[NumCounters]; + /// Index of the current counter being used. + unsigned int currentCountIdx_; }; +/// A device-wide semaphore. +/// This semaphore can be used to control access to a resource across multiple threads or blocks. +/// It uses atomic operations to ensure that the semaphore value is updated correctly across threads. +/// The semaphore value is an integer that can be set, acquired, and released. 
+/// +/// Example usage of DeviceSemaphore: +/// ```cpp +/// __global__ void myKernel(mscclpp::DeviceSemaphore* semaphore) { +/// // Initialize the semaphore (allow up to 3 threads access the resource simultaneously) +/// if (blockIdx.x == 0 && threadIdx.x == 0) { +/// semaphore->set(3); +/// } +/// // Each block acquires the semaphore before accessing the shared resource +/// if (threadIdx.x == 0) { +/// semaphore->acquire(); +/// } +/// __syncthreads(); +/// // Access the shared resource +/// // ... +/// __syncthreads(); +/// // Release the semaphore after accessing the shared resource +/// if (threadIdx.x == 0) { +/// semaphore->release(); +/// } +/// } +/// ``` struct DeviceSemaphore { public: /// Construct a new DeviceSemaphore object. @@ -61,11 +103,13 @@ struct DeviceSemaphore { MSCCLPP_INLINE ~DeviceSemaphore() = default; #if defined(MSCCLPP_DEVICE_COMPILE) - /// set the semaphore value. + /// Set the semaphore value. This function is used to initialize or reset the semaphore value. + /// The initial value is typically set to a positive integer to allow acquiring the semaphore. /// @param value The value to set. MSCCLPP_DEVICE_INLINE void set(int value) { atomicStore(&semaphore_, value, memoryOrderRelease); } - /// Acquire the semaphore. + /// Acquire the semaphore (wait until the semaphore value is greater than 0 and decrease it by 1). + /// This function will spin until the semaphore is acquired or the maximum spin count is reached. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. MSCCLPP_DEVICE_INLINE void acquire(int maxSpinCount = -1) { if (atomicFetchAdd(&semaphore_, -1, memoryOrderAcquire) <= 0) { @@ -73,7 +117,7 @@ struct DeviceSemaphore { } } - /// Release the semaphore. + /// Release the semaphore (increase the semaphore value by 1). 
MSCCLPP_DEVICE_INLINE void release() { atomicFetchAdd(&semaphore_, 1, memoryOrderRelease); } #endif // defined(MSCCLPP_DEVICE_COMPILE) diff --git a/include/mscclpp/copy_device.hpp b/include/mscclpp/copy_device.hpp index d8c29a2c..55a2a782 100644 --- a/include/mscclpp/copy_device.hpp +++ b/include/mscclpp/copy_device.hpp @@ -23,11 +23,12 @@ namespace detail { /// This function is intended to be collectively called by multiple threads. Each thread copies a part of /// elements. /// +/// @tparam T The type of the elements to be copied. /// @param dst The destination address. /// @param src The source address. /// @param numElems The number of elements to be copied. -/// @param threadId The index of the current thread among all threads running this function. This is different -/// from the `threadIdx` in CUDA. +/// @param threadId The index of the current thread among all threads running this function. +/// Should be less than @p numThreads. /// @param numThreads The total number of threads that run this function. /// template @@ -43,7 +44,27 @@ MSCCLPP_DEVICE_INLINE void copy(T* dst, T* src, uint64_t numElems, uint32_t thre } // namespace detail -/// this is a helper for copy function +/// Helper function of mscclpp::copy(). Copy data from the source memory to the destination memory. +/// +/// This function is intended to be collectively called by multiple threads. Each thread copies a part of +/// elements. +/// +/// @note The source and destination addresses do not have to be aligned to the size of @p T, but the misalignment +/// to the size of @p T should be multiple of 4 bytes and should be the same for both source and destination addresses. +/// The behavior of this function is undefined otherwise. +/// @note The number of bytes to be copied should be a multiple of 4 bytes. If the number of bytes is not a multiple +/// of 4 bytes, the remainder bytes will not be copied. +/// +/// @tparam T The type of the elements to be copied. 
+/// @tparam CopyRemainder If false, the function will not copy data that is unaligned to the size of @p T. If true, +/// the function will try to copy the unaligned data with conditions (see the notes). +/// @param dst The destination address. +/// @param src The source address. +/// @param bytes Bytes of the data to be copied. Should be a multiple of 4 bytes. +/// @param threadId The index of the current thread among all threads running this function. +/// Should be less than @p numThreads. +/// @param numThreads The total number of threads that run this function. +/// template MSCCLPP_DEVICE_INLINE void copyHelper(void* dst, void* src, uint64_t bytes, uint32_t threadId, uint32_t numThreads) { int* dstInt = reinterpret_cast(dst); @@ -70,17 +91,27 @@ MSCCLPP_DEVICE_INLINE void copyHelper(void* dst, void* src, uint64_t bytes, uint } } -/// Copy aligned data from the source memory to the destination memory. +/// Copy data from the source memory to the destination memory. /// -/// This function is a warpper of Element::copy(). Unlike Element::copy(), this function can copy remainder -/// bytes when @p CopyRemainder is true. Still, the 16. -/// @tparam CopyRemainder Whether to copy remainder bytes when the number of bytes is not a multiple of @p -/// Alignment. -/// @param dst The destination address. Should be aligned to @p Alignment in the same way as @p src. -/// @param src The source address. Should be aligned to @p Alignment in the same way as @p dst. -/// @param bytes Bytes of the data to be copied. Should be a multiple of @p Alignment. -/// @param threadId The index of the current thread among all threads running this function. This is different from -/// the `threadIdx` in CUDA. +/// This function is intended to be collectively called by multiple threads. Each thread copies a part of +/// elements. 
+/// +/// @note The source and destination addresses do not have to be aligned to the @p Alignment value, +/// but the misalignment to @p Alignment should be multiple of 4 bytes and should be the same for both source +/// and destination addresses. +/// The behavior of this function is undefined otherwise. +/// @note The number of bytes to be copied should be a multiple of 4 bytes. If the number of bytes is not a multiple +/// of 4 bytes, the remainder bytes will not be copied. +/// +/// @tparam Alignment The alignment of the data to be copied. A larger alignment value is more likely to achieve higher +/// copying throughput. Should be one of 4, 8, or 16. +/// @tparam CopyRemainder If false, the function will not copy data that is unaligned to the @p Alignment value. +/// If true, the function will try to copy the unaligned data with conditions (see the notes). +/// @param dst The destination address. +/// @param src The source address. +/// @param bytes Bytes of the data to be copied. Should be a multiple of 4 bytes. +/// @param threadId The index of the current thread among all threads running this function. +/// Should be less than @p numThreads. /// @param numThreads The total number of threads that run this function. /// template @@ -98,13 +129,17 @@ MSCCLPP_DEVICE_INLINE void copy(void* dst, void* src, uint64_t bytes, uint32_t t /// Read data from the origin and write packets to the target buffer. /// +/// This function is intended to be collectively called by multiple threads. Each thread copies a part of +/// packets. +/// +/// @tparam PacketType The packet type. It should be either LL16Packet or LL8Packet. /// @param targetPtr The target buffer. /// @param originPtr The origin buffer. /// @param originBytes The number of bytes to write to the target buffer. -/// @param threadId The thread ID. The thread ID should be less than @p numThreads. -/// @param numThreads The number of threads that call this function. -/// @param flag The flag to write. 
-/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. +/// @param threadId The index of the current thread among all threads running this function. +/// Should be less than @p numThreads. +/// @param numThreads The total number of threads that run this function. +/// @param flag The flag to write in the packets. /// template MSCCLPP_DEVICE_INLINE void copyToPackets(void* targetPtr, const void* originPtr, uint64_t originBytes, @@ -138,13 +173,17 @@ MSCCLPP_DEVICE_INLINE void copyToPackets(void* targetPtr, const void* /// Read packets from the target buffer and write retrieved data to the origin. /// -/// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. +/// This function is intended to be collectively called by multiple threads. Each thread reads a part of +/// packets. +/// +/// @tparam PacketType The packet type. It should be either LL16Packet or LL8Packet. /// @param originPtr The origin buffer. /// @param targetPtr The target buffer. /// @param originBytes The number of bytes to read from the origin buffer. -/// @param threadId The thread ID. The thread ID should be less than @p numThreads. -/// @param numThreads The number of threads that call this function. -/// @param flag The flag to read. +/// @param threadId The index of the current thread among all threads running this function. +/// Should be less than @p numThreads. +/// @param numThreads The total number of threads that run this function. +/// @param flag The flag to read from the packets. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. /// template diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 57fad56d..d4a8254c 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -25,23 +25,86 @@ namespace mscclpp { using UniqueId = std::array; /// Return a version string. 
+/// @return A string representing the version of MSCCL++ in the format "major.minor.patch". std::string version(); /// Base class for bootstraps. class Bootstrap { public: + /// Constructor. Bootstrap(){}; + + /// Destructor. virtual ~Bootstrap() = default; - virtual int getRank() = 0; - virtual int getNranks() = 0; - virtual int getNranksPerNode() = 0; + + /// Return the rank of the process. + /// @return The rank of the process. + virtual int getRank() const = 0; + + /// Return the total number of ranks. + /// @return The total number of ranks. + virtual int getNranks() const = 0; + + /// Return the total number of ranks per node. + /// @return The total number of ranks per node. + virtual int getNranksPerNode() const = 0; + + /// Send arbitrary data to another process. + /// + /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, + /// senderRank, tag)`. Multiple calls to send() with the same @p peer and @p tag will be ordered by + /// the order of calls, corresponding to the order of recv() calls on the receiving side. In cases where + /// the execution order of multiple send()s or recv()s between two ranks is unknown, they should be differentiated + /// by using different @p tag values to prevent unexpected behavior. + /// + /// @param data The data to send. + /// @param size The size of the data to send. + /// @param peer The rank of the process to send the data to. + /// @param tag The tag to send the data with. virtual void send(void* data, int size, int peer, int tag) = 0; + + /// Receive data sent from another process by send(). + /// + /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, + /// senderRank, tag)`. Multiple calls to send() with the same @p peer and @p tag will be ordered by + /// the order of calls, corresponding to the order of recv() calls on the receiving side. 
In cases where + /// the execution order of multiple send()s or recv()s between two ranks is unknown, they should be differentiated + /// by using different @p tag values to prevent unexpected behavior. + /// + /// @param data The buffer to write the received data to. + /// @param size The size of the data to receive. + /// @param peer The rank of the process to receive the data from. + /// @param tag The tag to receive the data with. virtual void recv(void* data, int size, int peer, int tag) = 0; + + /// Gather data from all processes. + /// + /// When called by rank `r`, this sends data from `allData[r * size]` to `allData[(r + 1) * size - 1]` to all other + /// ranks. The data sent by rank `r` is received into `allData[r * size]` of other ranks. + /// + /// @param allData The buffer to write the received data to. + /// @param size The size of the data each rank sends. virtual void allGather(void* allData, int size) = 0; + + /// Synchronize all processes. virtual void barrier() = 0; + /// A partial barrier that synchronizes a group of ranks. + /// @param ranks The ranks to synchronize. void groupBarrier(const std::vector& ranks); + + /// Wrapper of send() that sends a vector of characters. + /// @param data The data to send. + /// @param peer The rank of the process to send the data to. + /// @param tag The tag to send the data with. void send(const std::vector& data, int peer, int tag); + + /// Wrapper of recv() that receives a vector of characters. + /// @param data The buffer to write the received data to. + /// @param peer The rank of the process to receive the data from. + /// @param tag The tag to receive the data with. + /// + /// @note The data vector will be resized to the size of the received data. void recv(std::vector& data, int peer, int tag); }; @@ -60,33 +123,37 @@ class TcpBootstrap : public Bootstrap { /// Destructor. ~TcpBootstrap(); - /// Return the unique ID stored in the @ref TcpBootstrap. 
- /// @return The unique ID stored in the @ref TcpBootstrap. + /// Return the unique ID stored in the TcpBootstrap. + /// @return The unique ID stored in the TcpBootstrap. UniqueId getUniqueId() const; - /// Initialize the @ref TcpBootstrap with a given unique ID. - /// @param uniqueId The unique ID to initialize the @ref TcpBootstrap with. + /// Initialize the TcpBootstrap with a given unique ID. The unique ID can be generated by any methods; + /// it can be created by createUniqueId() or can be any arbitrary bit arrays provided by the user. + /// @param uniqueId The unique ID to initialize the TcpBootstrap with. /// @param timeoutSec The connection timeout in seconds. void initialize(UniqueId uniqueId, int64_t timeoutSec = 30); - /// Initialize the @ref TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port". + /// Initialize the TcpBootstrap with a string formatted as "ip:port" or "interface:ip:port". /// @param ifIpPortTrio The string formatted as "ip:port" or "interface:ip:port". /// @param timeoutSec The connection timeout in seconds. void initialize(const std::string& ifIpPortTrio, int64_t timeoutSec = 30); /// Return the rank of the process. - int getRank() override; + int getRank() const override; /// Return the total number of ranks. - int getNranks() override; + int getNranks() const override; /// Return the total number of ranks per node. - int getNranksPerNode() override; + int getNranksPerNode() const override; - /// Send data to another process. + /// Send arbitrary data to another process. /// /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, - /// senderRank, tag)`. + /// senderRank, tag)`. Multiple calls to send() with the same @p peer and @p tag will be ordered by + /// the order of calls, corresponding to the order of recv() calls on the receiving side. 
In cases where + /// the execution order of multiple send()s or recv()s between two ranks is unknown, they should be differentiated + /// by using different @p tag values to prevent unexpected behavior. /// /// @param data The data to send. /// @param size The size of the data to send. @@ -94,10 +161,13 @@ class TcpBootstrap : public Bootstrap { /// @param tag The tag to send the data with. void send(void* data, int size, int peer, int tag) override; - /// Receive data from another process. + /// Receive data sent from another process by send(). /// /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, - /// senderRank, tag)`. + /// senderRank, tag)`. Multiple calls to send() with the same @p peer and @p tag will be ordered by + /// the order of calls, corresponding to the order of recv() calls on the receiving side. In cases where + /// the execution order of multiple send()s or recv()s between two ranks is unknown, they should be differentiated + /// by using different @p tag values to prevent unexpected behavior. /// /// @param data The buffer to write the received data to. /// @param size The size of the data to receive. @@ -129,10 +199,7 @@ class TcpBootstrap : public Bootstrap { void barrier() override; private: - // The interal implementation. class Impl; - - // Pointer to the internal implementation. std::unique_ptr pimpl_; }; @@ -314,7 +381,9 @@ inline TransportFlags operator^(Transport transport1, Transport transport2) { class Context; class Connection; -/// Represents a block of memory that has been registered to a @ref Context. +/// Represents a block of memory that has been registered to a Context. +/// RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory +/// to other processes, hence allowing their access to the memory block. class RegisteredMemory { public: /// Default constructor. 
@@ -336,17 +405,17 @@ class RegisteredMemory { /// Get the size of the memory block. /// /// @return The size of the memory block. - size_t size(); + size_t size() const; /// Get the transport flags associated with the memory block. /// /// @return The transport flags associated with the memory block. - TransportFlags transports(); + TransportFlags transports() const; /// Serialize the RegisteredMemory object to a vector of characters. /// /// @return A vector of characters representing the serialized RegisteredMemory object. - std::vector serialize(); + std::vector serialize() const; /// Deserialize a RegisteredMemory object from a vector of characters. /// @@ -355,13 +424,8 @@ class RegisteredMemory { static RegisteredMemory deserialize(const std::vector& data); private: - // The interal implementation. struct Impl; - - // Internal constructor. RegisteredMemory(std::shared_ptr pimpl); - - // Pointer to the internal implementation. A shared_ptr is used since RegisteredMemory is immutable. std::shared_ptr pimpl_; friend class Context; @@ -396,13 +460,8 @@ class Endpoint { static Endpoint deserialize(const std::vector& data); private: - // The interal implementation. struct Impl; - - // Internal constructor. Endpoint(std::shared_ptr pimpl); - - // Pointer to the internal implementation. A shared_ptr is used since Endpoint is immutable. std::shared_ptr pimpl_; friend class Context; @@ -418,20 +477,20 @@ class Connection { virtual ~Connection() = default; - /// Write data from a source @ref RegisteredMemory to a destination @ref RegisteredMemory. + /// Write data from a source RegisteredMemory to a destination RegisteredMemory. /// - /// @param dst The destination @ref RegisteredMemory. - /// @param dstOffset The offset in bytes from the start of the destination @ref RegisteredMemory. - /// @param src The source @ref RegisteredMemory. - /// @param srcOffset The offset in bytes from the start of the source @ref RegisteredMemory. 
+ /// @param dst The destination RegisteredMemory. + /// @param dstOffset The offset in bytes from the start of the destination RegisteredMemory. + /// @param src The source RegisteredMemory. + /// @param srcOffset The offset in bytes from the start of the source RegisteredMemory. /// @param size The number of bytes to write. virtual void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) = 0; - /// Update a 8-byte value in a destination @ref RegisteredMemory and synchronize the change with the remote process. + /// Update a 8-byte value in a destination RegisteredMemory and synchronize the change with the remote process. /// - /// @param dst The destination @ref RegisteredMemory. - /// @param dstOffset The offset in bytes from the start of the destination @ref RegisteredMemory. + /// @param dst The destination RegisteredMemory. + /// @param dstOffset The offset in bytes from the start of the destination RegisteredMemory. /// @param src A pointer to the value to update. /// @param newValue The new value to write. virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) = 0; @@ -442,22 +501,22 @@ class Connection { /// Get the transport used by the local process. /// /// @return The transport used by the local process. - virtual Transport transport() = 0; + virtual Transport transport() const = 0; /// Get the transport used by the remote process. /// /// @return The transport used by the remote process. - virtual Transport remoteTransport() = 0; + virtual Transport remoteTransport() const = 0; /// Get the name of the transport used for this connection /// - /// @return name of @ref transport() -> @ref remoteTransport() - std::string getTransportName(); + /// @return A string formatted as "localTransportName -> remoteTransportName". 
+ std::string getTransportName() const; /// Get the maximum write queue size /// /// @return The maximum number of write requests that can be queued. - int getMaxWriteQueueSize(); + int getMaxWriteQueueSize() const; protected: // Internal methods for getting implementation pointers. @@ -500,19 +559,19 @@ struct EndpointConfig { }; /// Represents a context for communication. This provides a low-level interface for forming connections in use-cases -/// where the process group abstraction offered by @ref Communicator is not suitable, e.g., ephemeral client-server +/// where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server /// connections. Correct use of this class requires external synchronization when finalizing connections with the -/// @ref connect() method. +/// connect() method. /// /// As an example, a client-server scenario where the server will write to the client might proceed as follows: -/// 1. The client creates an endpoint with @ref createEndpoint() and sends it to the server. -/// 2. The server receives the client endpoint, creates its own endpoint with @ref createEndpoint(), sends it to the -/// client, and creates a connection with @ref connect(). -/// 4. The client receives the server endpoint, creates a connection with @ref connect() and sends a -/// @ref RegisteredMemory to the server. -/// 5. The server receives the @ref RegisteredMemory and writes to it using the previously created connection. -/// The client waiting to create a connection before sending the @ref RegisteredMemory ensures that the server can not -/// write to the @ref RegisteredMemory before the connection is established. +/// 1. The client creates an endpoint with createEndpoint() and sends it to the server. +/// 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the +/// client, and creates a connection with connect(). +/// 4. 
The client receives the server endpoint, creates a connection with connect() and sends a +/// RegisteredMemory to the server. +/// 5. The server receives the RegisteredMemory and writes to it using the previously created connection. +/// The client waiting to create a connection before sending the RegisteredMemory ensures that the server can not +/// write to the RegisteredMemory before the connection is established. /// /// While some transports may have more relaxed implementation behavior, this should not be relied upon. class Context { @@ -528,7 +587,7 @@ class Context { /// @param ptr Base pointer to the memory. /// @param size Size of the memory region in bytes. /// @param transports Transport flags. - /// @return RegisteredMemory A handle to the buffer. + /// @return A RegisteredMemory object representing the registered memory region. RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); /// Create an endpoint for establishing connections. @@ -543,14 +602,11 @@ class Context { /// /// @param localEndpoint The local endpoint. /// @param remoteEndpoint The remote endpoint. - /// @return std::shared_ptr A shared pointer to the connection. + /// @return A shared pointer to the connection. std::shared_ptr connect(Endpoint localEndpoint, Endpoint remoteEndpoint); private: - // The interal implementation. struct Impl; - - // Pointer to the internal implementation. std::unique_ptr pimpl_; friend class RegisteredMemory; @@ -571,6 +627,52 @@ using NonblockingFuture [[deprecated("Use std::shared_future instead. This will /// 4. Call get() on all futures returned by connect() and recvMemory(). /// 5. All done; use connections and registered memories to build channels. 
/// +/// Correct Example 1: +/// ```cpp +/// // Rank 0 +/// communicator.sendMemory(memory1, 1, tag); +/// communicator.sendMemory(memory2, 1, tag); +/// auto connection = communicator.connect(1, tag, Transport::CudaIpc); +/// connection.get(); // This will return the connection. +/// // Rank 1 +/// auto mem1 = communicator.recvMemory(0, tag); +/// auto mem2 = communicator.recvMemory(0, tag); +/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); +/// mem2.get(); // This will return memory2. +/// connection.get(); // This will return the connection. +/// mem1.get(); // This will return memory1. +/// ``` +/// +/// Correct Example 2: +/// ```cpp +/// // Rank 0 +/// communicator.sendMemory(memory0, 1, tag); +/// auto mem1 = communicator.recvMemory(1, tag); +/// auto connection = communicator.connect(1, tag, Transport::CudaIpc); +/// connection.get(); // This will return the connection. +/// mem1.get(); // This will return memory1. +/// // Rank 1 +/// auto mem0 = communicator.recvMemory(0, tag); +/// communicator.sendMemory(memory1, 0, tag); +/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); +/// mem0.get(); // This will return memory0. +/// connection.get(); // This will return the connection. +/// ``` +/// +/// Wrong Example: +/// ```cpp +/// // Rank 0 +/// communicator.sendMemory(memory0, 1, tag); +/// auto mem1 = communicator.recvMemory(1, tag); +/// auto connection = communicator.connect(1, tag, Transport::CudaIpc); +/// // Rank 1 +/// auto mem0 = communicator.recvMemory(0, tag); +/// auto connection = communicator.connect(0, tag, Transport::CudaIpc); // undefined behavior +/// communicator.sendMemory(memory1, 0, tag); +/// ``` +/// In the wrong example, the connection information from rank 1 will be sent to `mem1` object on rank 0, +/// where the object type is RegisteredMemory, not Connection. +/// class Communicator { public: /// Initializes the communicator with a given bootstrap implementation. 
@@ -584,12 +686,12 @@ class Communicator { /// Returns the bootstrap held by this communicator. /// - /// @return std::shared_ptr The bootstrap held by this communicator. + /// @return The bootstrap held by this communicator. std::shared_ptr bootstrap(); /// Returns the context held by this communicator. /// - /// @return std::shared_ptr The context held by this communicator. + /// @return The context held by this communicator. std::shared_ptr context(); /// Register a region of GPU memory for use in this communicator's context. @@ -597,16 +699,24 @@ class Communicator { /// @param ptr Base pointer to the memory. /// @param size Size of the memory region in bytes. /// @param transports Transport flags. - /// @return RegisteredMemory A handle to the buffer. + /// @return A RegisteredMemory object representing the registered memory region. RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports); /// Send information of a registered memory to the remote side. /// - /// The send will be performed immediately upon calling this function. + /// The send will be started upon calling this function, but this function returns immediately without + /// waiting for the completion of the send. RegisteredMemory sent via `sendMemory(memory, remoteRank, tag)` can be + /// received via `recvMemory(remoteRank, tag)`. + /// + /// Multiple calls to either sendMemory() or connect() with the same @p remoteRank and @p tag will be ordered by + /// the order of calls, corresponding to the order of recvMemory() or connect() calls on the receiving side. + /// In cases where the execution order is unknown between two ranks, they should be differentiated by using + /// different @p tag values to prevent unexpected behavior. /// /// @param memory The registered memory buffer to send information about. /// @param remoteRank The rank of the remote process. /// @param tag The tag to use for identifying the send. 
+ /// void sendMemory(RegisteredMemory memory, int remoteRank, int tag); [[deprecated("Use sendMemory() instead. This will be removed in a future release.")]] void sendMemoryOnSetup( @@ -617,11 +727,25 @@ class Communicator { /// Receive memory information from a corresponding sendMemory call on the remote side. /// /// This function returns a future immediately. The actual receive will be performed upon calling - /// the first get() on the future. + /// the first get() on the future. RegisteredMemory sent via `sendMemory(memory, remoteRank, tag)` can be + /// received via `recvMemory(remoteRank, tag)`. + /// + /// Multiple calls to either sendMemory() or connect() with the same @p remoteRank and @p tag will be ordered by + /// the order of calls, corresponding to the order of recvMemory() or connect() calls on the receiving side. + /// In cases where the execution order is unknown between two ranks, they should be differentiated by using + /// different @p tag values to prevent unexpected behavior. + /// + /// @note To guarantee the receiving order, calling get() on a future returned by recvMemory() or connect() + /// may start receiving other RegisteredMemory or Connection objects of which futures were returned by + /// an earlier call to recvMemory() or connect() with the same @p remoteRank and @p tag. For example, if + /// we call recvMemory() or connect() five times with the same @p remoteRank and @p tag and then call get() + /// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order, + /// back to back. /// /// @param remoteRank The rank of the remote process. /// @param tag The tag to use for identifying the receive. - /// @return std::shared_future A non-blocking future of registered memory. + /// @return A future of registered memory. + /// std::shared_future recvMemory(int remoteRank, int tag); [[deprecated( @@ -632,18 +756,32 @@ class Communicator { /// Connect to a remote rank. 
/// - /// This function will immediately send metadata about the local endpoint to the remote rank, and return a future - /// without waiting for the remote rank to respond. The connection will be established when the remote rank - /// responds with its own endpoint and the local rank calls the first get() on the future. + /// This function will start (but not wait for) sending metadata about the local endpoint to the remote rank, + /// and return a future connection without waiting for the remote rank to respond. + /// The connection will be established when the remote rank responds with its own endpoint and the local rank calls + /// the first get() on the future. /// Note that this function is two-way and a connection from rank `i` to remote rank `j` needs /// to have a counterpart from rank `j` to rank `i`. Note that with IB, buffers are registered at a page level and if /// a buffer is spread through multiple pages and do not fully utilize all of them, IB's QP has to register for all /// involved pages. This potentially has security risks if the connection's accesses are given to a malicious process. /// + /// Multiple calls to either sendMemory() or connect() with the same @p remoteRank and @p tag will be ordered by + /// the order of calls, corresponding to the order of recvMemory() or connect() calls on the receiving side. + /// In cases where the execution order is unknown between two ranks, they should be differentiated by using + /// different @p tag values to prevent unexpected behavior. + /// + /// @note To guarantee the receiving order, calling get() on a future returned by recvMemory() or connect() + /// may start receiving other RegisteredMemory or Connection objects of which futures were returned by + /// an earlier call to recvMemory() or connect() with the same @p remoteRank and @p tag. 
For example, if + /// we call recvMemory() or connect() five times with the same @p remoteRank and @p tag and then call get() + /// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order, + /// back to back. + /// /// @param remoteRank The rank of the remote process. - /// @param tag The tag of the connection for identifying it. - /// @param config The configuration for the local endpoint. - /// @return std::shared_future> A non-blocking future of shared pointer to the connection. + /// @param tag The tag to use for identifying the send and receive. + /// @param localConfig The configuration for the local endpoint. + /// @return A future of shared pointer to the connection. + /// std::shared_future> connect(int remoteRank, int tag, EndpointConfig localConfig); [[deprecated("Use connect() instead. This will be removed in a future release.")]] NonblockingFuture< @@ -667,10 +805,7 @@ class Communicator { [[deprecated("setup() is now no-op and no longer needed. This will be removed in a future release.")]] void setup() {} private: - // The interal implementation. struct Impl; - - // Pointer to the internal implementation. std::unique_ptr pimpl_; }; diff --git a/include/mscclpp/env.hpp b/include/mscclpp/env.hpp index 1e165d93..147b90ed 100644 --- a/include/mscclpp/env.hpp +++ b/include/mscclpp/env.hpp @@ -16,24 +16,78 @@ class Env; std::shared_ptr env(); /// The MSCCL++ environment. The constructor reads environment variables and sets the corresponding fields. -/// Use the @ref env() function to get the environment object. +/// Use the env() function to get the environment object. class Env { public: + /// Env name: `MSCCLPP_DEBUG`. The debug flag, one of VERSION, WARN, INFO, ABORT, or TRACE. Unset by default. const std::string debug; + + /// Env name: `MSCCLPP_DEBUG_SUBSYS`. The debug subsystem, a comma-separated list of subsystems to enable + /// debug logging for. 
+ /// If the first character is '^', it inverts the mask, i.e., enables all subsystems except those specified. + /// Possible values are INIT, COLL, P2P, SHM, NET, GRAPH, TUNING, ENV, ALLOC, CALL, MSCCLPP_EXECUTOR, ALL. + /// Unset by default. const std::string debugSubsys; + + /// Env name: `MSCCLPP_DEBUG_FILE`. A file path to write debug logs to. Unset by default. const std::string debugFile; + + /// Env name: `MSCCLPP_HCA_DEVICES`. A comma-separated list of HCA devices to use for IB transport. i-th device + /// in the list will be used for the i-th GPU in the system. If unset, it will use ibverbs APIs to find the + /// devices automatically. const std::string hcaDevices; + + /// Env name: `MSCCLPP_HOSTID`. A string that uniquely identifies the host. If unset, it will use the hostname. + /// This is used to determine whether the host is the same across different processes. const std::string hostid; + + /// Env name: `MSCCLPP_SOCKET_FAMILY`. The socket family to use for TCP sockets (used by TcpBootstrap and + /// the Ethernet transport). Possible values are `AF_INET` (IPv4) and `AF_INET6` (IPv6). + /// If unset, it will not force any family and will use the first one found. const std::string socketFamily; + + /// Env name: `MSCCLPP_SOCKET_IFNAME`. The interface name to use for TCP sockets (used by TcpBootstrap and + /// the Ethernet transport). If unset, it will use the first interface found that matches the socket family. const std::string socketIfname; + + /// Env name: `MSCCLPP_COMM_ID`. To be deprecated; don't use this. const std::string commId; + + /// Env name: `MSCCLPP_EXECUTION_PLAN_DIR`. The directory to find execution plans from. This should be set to + /// use execution plans for the NCCL API. Unset by default. const std::string executionPlanDir; + + /// Env name: `MSCCLPP_NPKIT_DUMP_DIR`. The directory to dump NPKIT traces to. If this is set, NPKIT will be + /// enabled and will dump traces to this directory. Unset by default. 
const std::string npkitDumpDir; + + /// Env name: `MSCCLPP_CUDAIPC_USE_DEFAULT_STREAM`. If set to true, the CUDA IPC transport will use the default + /// stream for all operations. If set to false, it will use a separate stream for each operation. This is an + /// experimental feature and should be false in most cases. Default is false. const bool cudaIpcUseDefaultStream; + + /// Env name: `MSCCLPP_NCCL_LIB_PATH`. The path to the original NCCL/RCCL shared library. If set, it will be used + /// as a fallback for NCCL operations in cases where the MSCCL++ NCCL cannot work. const std::string ncclSharedLibPath; + + /// Env name: `MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION`. A comma-separated list of NCCL operations that should + /// always use the fallback implementation, even if the MSCCL++ NCCL can handle them. This is useful for + /// debugging purposes. Currently supports `all`, `broadcast`, `allreduce`, `reducescatter`, and `allgather`. const std::string forceNcclFallbackOperation; + + /// Env name: `MSCCLPP_ENABLE_NCCL_FALLBACK`. If set to true, it will enable the fallback implementation for NCCL + /// operations. This is useful for debugging purposes. Default is false. const bool enableNcclFallback; + + /// Env name: `MSCCLPP_DISABLE_CHANNEL_CACHE`. If set to true, it will disable the channel cache for NCCL APIs. + /// Currently, this should be set to true if the application may call NCCL APIs on the same local buffer with + /// different remote buffers, e.g., in the case of a dynamic communicator. If CUDA/HIP graphs are used, disabling + /// the channel cache won't affect the performance, but otherwise it may lead to performance degradation. + /// Default is false. const bool disableChannelCache; + + /// Env name: `MSCCLPP_FORCE_DISABLE_NVLS`. If set to true, it will disable the NVLS support in MSCCL++. + /// Default is false. 
const bool forceDisableNvls; private: diff --git a/include/mscclpp/errors.hpp b/include/mscclpp/errors.hpp index 8d3fde4d..de9bc2dd 100644 --- a/include/mscclpp/errors.hpp +++ b/include/mscclpp/errors.hpp @@ -28,13 +28,13 @@ std::string errorToString(enum ErrorCode error); /// Base class for all errors thrown by MSCCL++. class BaseError : public std::runtime_error { public: - /// Constructor for @ref BaseError. + /// Constructor of BaseError. /// /// @param message The error message. /// @param errorCode The error code. BaseError(const std::string& message, int errorCode); - /// Constructor for @ref BaseError. + /// Constructor of BaseError. /// /// @param errorCode The error code. explicit BaseError(int errorCode); diff --git a/include/mscclpp/fifo.hpp b/include/mscclpp/fifo.hpp index 126f901e..5f841e22 100644 --- a/include/mscclpp/fifo.hpp +++ b/include/mscclpp/fifo.hpp @@ -17,16 +17,16 @@ constexpr size_t DEFAULT_FIFO_SIZE = 128; /// A class representing a host proxy FIFO that can consume work elements pushed by device threads. class Fifo { public: - /// Constructs a new @ref Fifo object. + /// Constructs a new Fifo object. /// @param size The number of entires in the FIFO. Fifo(int size = DEFAULT_FIFO_SIZE); - /// Destroys the @ref Fifo object. + /// Destroys the Fifo object. ~Fifo(); /// Polls the FIFO for a trigger. /// - /// Returns @ref ProxyTrigger which is the trigger at the head of fifo. + /// Returns ProxyTrigger which is the trigger at the head of fifo. ProxyTrigger poll(); /// Pops a trigger from the FIFO. @@ -41,10 +41,10 @@ class Fifo { /// @return The FIFO size. int size() const; - /// Returns a @ref FifoDeviceHandle object representing the device FIFO. + /// Returns a FifoDeviceHandle object representing the device FIFO. /// - /// @return A @ref FifoDeviceHandle object representing the device FIFO. - FifoDeviceHandle deviceHandle(); + /// @return A FifoDeviceHandle object representing the device FIFO. 
+ FifoDeviceHandle deviceHandle() const; private: struct Impl; diff --git a/include/mscclpp/fifo_device.hpp b/include/mscclpp/fifo_device.hpp index f431b1d8..faef6470 100644 --- a/include/mscclpp/fifo_device.hpp +++ b/include/mscclpp/fifo_device.hpp @@ -20,7 +20,7 @@ namespace mscclpp { /// This struct is used as a work element in the concurrent FIFO where multiple device threads can push /// ProxyTrigger elements and a single host proxy thread consumes these work elements. /// -/// Do not use the most significant bit of @ref snd as it is reserved for memory consistency purposes +/// Do not use the most significant bit of snd as it is reserved for memory consistency purposes. struct alignas(16) ProxyTrigger { uint64_t fst, snd; }; @@ -29,11 +29,11 @@ struct alignas(16) ProxyTrigger { /// work elements and a single host proxy thread consumes them. /// /// The FIFO has a head pointer allocated on the device which starts at 0 and goes up to 2^64-1, which is almost -/// infinity. There are two copies of the tail, one on the device, @ref FifoDeviceHandle::tailReplica, and another on +/// infinity. There are two copies of the tail, one on the device, FifoDeviceHandle::tailReplica, and another on /// the host, namely, hostTail. The host always has the "true" tail and occasionally pushes it to the copy on the /// device. Therefore, most of the time, the device has a stale version. The invariants are: tailReplica <= hostTail <= -/// head. The @ref push() function increments head, hostTail is updated in @ref Fifo::pop(), and it occasionally flushes -/// it to tailReplica via @ref Fifo::flushTail(). +/// head. The push() function increments head, hostTail is updated in Fifo::pop(), and it occasionally flushes +/// it to tailReplica via Fifo::flushTail(). /// /// Duplicating the tail is a good idea because the FIFO is large enough, and we do not need frequent updates for the /// tail as there is usually enough space for device threads to push their work into. 
diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index 72755007..595149d7 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -11,7 +11,7 @@ #include "gpu.hpp" #include "utils.hpp" -/// Throw @ref mscclpp::CudaError if @p cmd does not return cudaSuccess. +/// Throw mscclpp::CudaError if @p cmd does not return cudaSuccess. /// @param cmd The command to execute. #define MSCCLPP_CUDATHROW(cmd) \ do { \ @@ -22,7 +22,7 @@ } \ } while (false) -/// Throw @ref mscclpp::CuError if @p cmd does not return CUDA_SUCCESS. +/// Throw mscclpp::CuError if @p cmd does not return CUDA_SUCCESS. /// @param cmd The command to execute. #define MSCCLPP_CUTHROW(cmd) \ do { \ @@ -36,7 +36,7 @@ namespace mscclpp { /// A RAII guard that will cudaThreadExchangeStreamCaptureMode to cudaStreamCaptureModeRelaxed on construction and -/// restore the previous mode on destruction. This is helpful when we want to avoid CUDA graph capture. +/// restore the previous mode on destruction. This is helpful when we want to avoid CUDA/HIP graph capture. struct AvoidCudaGraphCaptureGuard { AvoidCudaGraphCaptureGuard(); ~AvoidCudaGraphCaptureGuard(); @@ -45,12 +45,29 @@ struct AvoidCudaGraphCaptureGuard { /// A RAII wrapper around cudaStream_t that will call cudaStreamDestroy on destruction. struct CudaStreamWithFlags { + /// Constructor without flags. This will not create any stream. set() can be called later to create a stream with + /// specified flags. CudaStreamWithFlags() : stream_(nullptr) {} + + /// Constructor with flags. This will create a stream with the specified flags on the current device. + /// @param flags The flags to create the stream with. CudaStreamWithFlags(unsigned int flags); + + /// Destructor. This will destroy the stream if it was created. ~CudaStreamWithFlags(); + + /// Set the stream with the specified flags. If the stream was already created, it will raise an error with + /// ErrorCode::InvalidUsage. 
+ /// @param flags The flags to create the stream with. + /// @throws Error if the stream was already created. void set(unsigned int flags); + + /// Check if the stream is empty (not created). + /// @return true if the stream is empty, false otherwise. bool empty() const; + operator cudaStream_t() const { return stream_; } + cudaStream_t stream_; }; @@ -188,11 +205,24 @@ size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFla } // namespace detail +/// Copies memory from src to dst asynchronously. +/// @tparam T Type of each element in the memory. +/// @param dst Destination address. +/// @param src Source address. +/// @param nelems Number of elements to copy. +/// @param stream The stream to use for the copy operation. +/// @param kind The kind of copy operation. Default is cudaMemcpyDefault. template void gpuMemcpyAsync(T* dst, const T* src, size_t nelems, cudaStream_t stream, cudaMemcpyKind kind = cudaMemcpyDefault) { detail::gpuMemcpyAsync(dst, src, nelems * sizeof(T), stream, kind); } +/// Copies memory from src to dst synchronously. +/// @tparam T Type of each element in the memory. +/// @param dst Destination address. +/// @param src Source address. +/// @param nelems Number of elements to copy. +/// @param kind The kind of copy operation. Default is cudaMemcpyDefault. template void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMemcpyDefault) { detail::gpuMemcpy(dst, src, nelems * sizeof(T), kind); } @@ -203,10 +233,10 @@ void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMe /// @return True if NVLink SHARP (NVLS) is supported, false otherwise. bool isNvlsSupported(); -/// Check if ptr is allocaed by cuMemMap +/// Check if ptr is allocated by cuMemMap. /// @param ptr The pointer to check. /// @return True if the pointer is allocated by cuMemMap, false otherwise. 
-bool isCuMemMapAllocated([[maybe_unused]] void* ptr); +bool isCuMemMapAllocated(void* ptr); /// Allocates a GPU memory space specialized for communication. The memory is zeroed out. Get the device pointer by /// `GpuBuffer::data()`. diff --git a/include/mscclpp/memory_channel.hpp b/include/mscclpp/memory_channel.hpp index f5b44985..b685e93e 100644 --- a/include/mscclpp/memory_channel.hpp +++ b/include/mscclpp/memory_channel.hpp @@ -29,7 +29,7 @@ struct BaseMemoryChannel { BaseMemoryChannel& operator=(BaseMemoryChannel& other) = default; - /// Device-side handle for @ref BaseMemoryChannel. + /// Device-side handle for BaseMemoryChannel. using DeviceHandle = BaseMemoryChannelDeviceHandle; /// Returns the device-side handle. @@ -59,7 +59,7 @@ struct MemoryChannel : public BaseMemoryChannel { MemoryChannel(std::shared_ptr semaphore, RegisteredMemory dst, void* src, void* packetBuffer = nullptr); - /// Device-side handle for @ref MemoryChannel. + /// Device-side handle for MemoryChannel. using DeviceHandle = MemoryChannelDeviceHandle; /// Returns the device-side handle. @@ -69,7 +69,7 @@ struct MemoryChannel : public BaseMemoryChannel { DeviceHandle deviceHandle() const; }; -/// @deprecated Use @ref MemoryChannel instead. +/// @deprecated Use MemoryChannel instead. [[deprecated("Use MemoryChannel instead.")]] typedef MemoryChannel SmChannel; } // namespace mscclpp diff --git a/include/mscclpp/memory_channel_device.hpp b/include/mscclpp/memory_channel_device.hpp index 64077d1a..2d9451bb 100644 --- a/include/mscclpp/memory_channel_device.hpp +++ b/include/mscclpp/memory_channel_device.hpp @@ -148,7 +148,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle { /// /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. /// - /// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. + /// @tparam PacketType The packet type. 
It should be either LL16Packet or LL8Packet. /// @param targetOffset The offset in bytes of the remote address. /// @param originOffset The offset in bytes of the local address. /// @param originBytes Bytes of the origin to be copied. @@ -175,7 +175,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle { /// Retrieve data from a packet in the local packet buffer. /// - /// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. + /// @tparam PacketType The packet type. It should be either LL16Packet or LL8Packet. /// @param index The index of the packet to be read. The offset in bytes is calculated as index * sizeof(PacketType). /// @param flag The flag to read. /// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative. @@ -191,7 +191,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle { /// /// This function is intended to be collectively called by multiple threads. Each thread copies a part of data. /// - /// @tparam PacketType The packet type. It should be either @ref LL16Packet or @ref LL8Packet. + /// @tparam PacketType The packet type. It should be either LL16Packet or LL8Packet. /// @param targetOffset The offset in bytes of the local packet buffer. /// @param originOffset The offset in bytes of the local address. /// @param originBytes Bytes of the origin to be copied. @@ -229,7 +229,7 @@ struct MemoryChannelDeviceHandle : public BaseMemoryChannelDeviceHandle { #endif // defined(MSCCLPP_DEVICE_COMPILE) }; -/// @deprecated Use @ref MemoryChannelDeviceHandle instead. +/// @deprecated Use MemoryChannelDeviceHandle instead. 
[[deprecated("Use MemoryChannelDeviceHandle instead.")]] typedef MemoryChannelDeviceHandle SmChannelDeviceHandle; } // namespace mscclpp diff --git a/include/mscclpp/numa.hpp b/include/mscclpp/numa.hpp index 64eddb17..e8f2ab44 100644 --- a/include/mscclpp/numa.hpp +++ b/include/mscclpp/numa.hpp @@ -6,7 +6,15 @@ namespace mscclpp { -int getDeviceNumaNode(int cudaDev); +/// Return the NUMA node ID of the given GPU device ID. +/// @param deviceId The GPU device ID. +/// @return The NUMA node ID of the device. +/// @throw Error if the device ID is invalid or if the NUMA node cannot be determined. +int getDeviceNumaNode(int deviceId); + +/// NUMA bind the current thread to the specified NUMA node. +/// @param node The NUMA node ID to bind to. +/// @throw Error if the given NUMA node ID is invalid. void numaBind(int node); } // namespace mscclpp diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp index 25d5d7f1..6fc77584 100644 --- a/include/mscclpp/nvls.hpp +++ b/include/mscclpp/nvls.hpp @@ -36,8 +36,8 @@ class NvlsConnection { friend class NvlsConnection; }; - /// @brief bind the memory allocated via @ref mscclpp::GpuBuffer to the multicast handle. The behavior - /// is undefined if the devicePtr is not allocated by @ref mscclpp::GpuBuffer. + /// @brief bind the memory allocated via mscclpp::GpuBuffer to the multicast handle. The behavior + /// is undefined if the devicePtr is not allocated by mscclpp::GpuBuffer. /// @param devicePtr The device pointer returned by `mscclpp::GpuBuffer::data()`. /// @param size The bytes of the memory to bind to the multicast handle. 
/// @return DeviceMulticastPointer with devicePtr, mcPtr and bufferSize diff --git a/include/mscclpp/nvls_device.hpp b/include/mscclpp/nvls_device.hpp index 622a1a59..8b7de21b 100644 --- a/include/mscclpp/nvls_device.hpp +++ b/include/mscclpp/nvls_device.hpp @@ -20,7 +20,7 @@ namespace mscclpp { template constexpr bool dependentFalse = false; // workaround before CWG2518/P2593R1 -/// Device-side handle for @ref Host2DeviceSemaphore. +/// Device-side handle for Host2DeviceSemaphore. struct DeviceMulticastPointerDeviceHandle { void* devicePtr; void* mcPtr; diff --git a/include/mscclpp/packet_device.hpp b/include/mscclpp/packet_device.hpp index 6f3e5f85..f7ff0b8b 100644 --- a/include/mscclpp/packet_device.hpp +++ b/include/mscclpp/packet_device.hpp @@ -15,7 +15,7 @@ #endif // defined(MSCCLPP_DEVICE_COMPILE) namespace mscclpp { -/// LL (low latency) protocol packet. +/// LL (low latency) protocol packet with 8 bytes of data and 8 bytes of flags. union alignas(16) LL16Packet { // Assume data is written with an atomicity of 8 bytes (IB/RDMA). struct { @@ -59,7 +59,7 @@ union alignas(16) LL16Packet { /// @param flag The flag to write. MSCCLPP_DEVICE_INLINE void write(uint2 val, uint32_t flag) { write(val.x, val.y, flag); } - /// Helper of @ref read(). + /// Helper of read(). /// @param flag The flag to read. /// @param data The 8-byte data read. /// @return True if the flag is not equal to the given flag. @@ -81,7 +81,7 @@ union alignas(16) LL16Packet { #endif } - /// Read 8 bytes of data from the packet. + /// Read 8 bytes of data from the packet. It will spin until the flag is equal to the given flag. /// @param flag The flag to read. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. /// @return The 8-byte data read. @@ -96,6 +96,7 @@ union alignas(16) LL16Packet { #endif // defined(MSCCLPP_DEVICE_COMPILE) }; +/// LL (low latency) protocol packet with 4 bytes of data and 4 bytes of flags. 
union alignas(8) LL8Packet { // Assume data is written with an atomicity of 8 bytes (IB/RDMA). struct { @@ -111,6 +112,9 @@ union alignas(8) LL8Packet { MSCCLPP_DEVICE_INLINE LL8Packet(uint32_t val, uint32_t flag) : data(val), flag(flag) {} + /// Write 4 bytes of data to the packet. + /// @param val The 4-byte data to write. + /// @param flag The flag to write. MSCCLPP_DEVICE_INLINE void write(uint32_t val, uint32_t flag) { #if defined(MSCCLPP_DEVICE_CUDA) asm volatile("st.volatile.global.v2.u32 [%0], {%1,%2};" ::"l"(&raw_), "r"(val), "r"(flag)); @@ -121,6 +125,10 @@ union alignas(8) LL8Packet { #endif } + /// Helper of read(). + /// @param flag The flag to read. + /// @param data The 4-byte data read. + /// @return True if the flag is not equal to the given flag. MSCCLPP_DEVICE_INLINE bool readOnce(uint32_t flag, uint32_t& data) const { #if defined(MSCCLPP_DEVICE_CUDA) uint32_t f; @@ -135,6 +143,10 @@ union alignas(8) LL8Packet { #endif } + /// Read 4 bytes of data from the packet. It will spin until the flag is equal to the given flag. + /// @param flag The flag to read. + /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. + /// @return The 4-byte data read. MSCCLPP_DEVICE_INLINE uint32_t read(uint32_t flag, int64_t maxSpinCount = 1000000) const { uint32_t data; POLL_MAYBE_JAILBREAK(readOnce(flag, data), maxSpinCount); diff --git a/include/mscclpp/port_channel.hpp b/include/mscclpp/port_channel.hpp index 3d5a6284..dfcdf0cd 100644 --- a/include/mscclpp/port_channel.hpp +++ b/include/mscclpp/port_channel.hpp @@ -27,6 +27,7 @@ class BaseProxyService { class ProxyService : public BaseProxyService { public: /// Constructor. + /// @param fifoSize The size of the FIFO used by the proxy service. Default is DEFAULT_FIFO_SIZE. ProxyService(size_t fifoSize = DEFAULT_FIFO_SIZE); /// Build and add a semaphore to the proxy service. 
@@ -98,7 +99,7 @@ struct BasePortChannel { BasePortChannel& operator=(BasePortChannel& other) = default; - /// Device-side handle for @ref BasePortChannel. + /// Device-side handle for BasePortChannel. using DeviceHandle = BasePortChannelDeviceHandle; /// Returns the device-side handle. @@ -133,7 +134,7 @@ struct PortChannel : public BasePortChannel { /// Assignment operator. PortChannel& operator=(PortChannel& other) = default; - /// Device-side handle for @ref PortChannel. + /// Device-side handle for PortChannel. using DeviceHandle = PortChannelDeviceHandle; /// Returns the device-side handle. @@ -143,10 +144,10 @@ struct PortChannel : public BasePortChannel { DeviceHandle deviceHandle() const; }; -/// @deprecated Use @ref BasePortChannel instead. +/// @deprecated Use BasePortChannel instead. [[deprecated("Use BasePortChannel instead.")]] typedef BasePortChannel BaseProxyChannel; -/// @deprecated Use @ref PortChannel instead. +/// @deprecated Use PortChannel instead. [[deprecated("Use PortChannel instead.")]] typedef PortChannel ProxyChannel; } // namespace mscclpp diff --git a/include/mscclpp/port_channel_device.hpp b/include/mscclpp/port_channel_device.hpp index 8e1fcd06..eff29643 100644 --- a/include/mscclpp/port_channel_device.hpp +++ b/include/mscclpp/port_channel_device.hpp @@ -9,9 +9,11 @@ namespace mscclpp { +/// Numeric ID of Semaphore. ProxyService has an internal array indexed by these handles mapping to the +/// actual semaphores. using SemaphoreId = uint32_t; -/// Numeric ID of @ref RegisteredMemory. @ref ProxyService has an internal array indexed by these handles mapping to the +/// Numeric ID of RegisteredMemory. ProxyService has an internal array indexed by these handles mapping to the /// actual. using MemoryId = uint32_t; @@ -108,7 +110,7 @@ struct BasePortChannelDeviceHandle { : semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {} #if defined(MSCCLPP_DEVICE_COMPILE) - /// Push a @ref TriggerData to the FIFO. 
+ /// Push a TriggerData to the FIFO. /// @param dst The destination memory region. /// @param dstOffset The offset into the destination memory region. /// @param src The source memory region. @@ -118,7 +120,7 @@ struct BasePortChannelDeviceHandle { fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); } - /// Push a @ref TriggerData to the FIFO. + /// Push a TriggerData to the FIFO. /// @param dst The destination memory region. /// @param src The source memory region. /// @param offset The common offset into the destination and source memory regions. @@ -127,10 +129,10 @@ struct BasePortChannelDeviceHandle { put(dst, offset, src, offset, size); } - /// Push a @ref TriggerFlag to the FIFO. + /// Push a TriggerFlag to the FIFO. MSCCLPP_DEVICE_INLINE void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); } - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// Push a TriggerData and a TriggerFlag at the same time to the FIFO. /// @param dst The destination memory region. /// @param dstOffset The offset into the destination memory region. /// @param src The source memory region. @@ -141,7 +143,7 @@ struct BasePortChannelDeviceHandle { fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); } - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// Push a TriggerData and a TriggerFlag at the same time to the FIFO. /// @param dst The destination memory region. /// @param src The source memory region. /// @param offset The common offset into the destination and source memory regions. @@ -150,7 +152,7 @@ struct BasePortChannelDeviceHandle { putWithSignal(dst, offset, src, offset, size); } - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. 
+ /// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO. /// @param dst The destination memory region. /// @param dstOffset The offset into the destination memory region. /// @param src The source memory region. @@ -165,7 +167,7 @@ struct BasePortChannelDeviceHandle { fifo_.sync(curFifoHead, maxSpinCount); } - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. + /// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO. /// @param dst The destination memory region. /// @param src The source memory region. /// @param offset The common offset into the destination and source memory regions. @@ -176,7 +178,7 @@ struct BasePortChannelDeviceHandle { putWithSignalAndFlush(dst, offset, src, offset, size, maxSpinCount); } - /// Push a @ref TriggerSync to the FIFO. + /// Push a TriggerSync to the FIFO. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. MSCCLPP_DEVICE_INLINE void flush(int64_t maxSpinCount = 1000000) { uint64_t curFifoHead = fifo_.push(ChannelTrigger(TriggerSync, 0, 0, 0, 0, 1, semaphoreId_).value); @@ -206,7 +208,7 @@ struct PortChannelDeviceHandle : public BasePortChannelDeviceHandle { : BasePortChannelDeviceHandle(semaphoreId, semaphore, fifo), dst_(dst), src_(src) {} #if defined(MSCCLPP_DEVICE_COMPILE) - /// Push a @ref TriggerData to the FIFO. + /// Push a TriggerData to the FIFO. /// @param dstOffset The offset into the destination memory region. /// @param srcOffset The offset into the source memory region. /// @param size The size of the transfer. @@ -214,12 +216,12 @@ struct PortChannelDeviceHandle : public BasePortChannelDeviceHandle { BasePortChannelDeviceHandle::put(dst_, dstOffset, src_, srcOffset, size); } - /// Push a @ref TriggerData to the FIFO. + /// Push a TriggerData to the FIFO. /// @param offset The common offset into the destination and source memory regions. 
/// @param size The size of the transfer. MSCCLPP_DEVICE_INLINE void put(uint64_t offset, uint64_t size) { put(offset, offset, size); } - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// Push a TriggerData and a TriggerFlag at the same time to the FIFO. /// @param dstOffset The offset into the destination memory region. /// @param srcOffset The offset into the source memory region. /// @param size The size of the transfer. @@ -227,12 +229,12 @@ struct PortChannelDeviceHandle : public BasePortChannelDeviceHandle { BasePortChannelDeviceHandle::putWithSignal(dst_, dstOffset, src_, srcOffset, size); } - /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// Push a TriggerData and a TriggerFlag at the same time to the FIFO. /// @param offset The common offset into the destination and source memory regions. /// @param size The size of the transfer. MSCCLPP_DEVICE_INLINE void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); } - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. + /// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO. /// @param dstOffset The offset into the destination memory region. /// @param srcOffset The offset into the source memory region. /// @param size The size of the transfer. @@ -242,7 +244,7 @@ struct PortChannelDeviceHandle : public BasePortChannelDeviceHandle { BasePortChannelDeviceHandle::putWithSignalAndFlush(dst_, dstOffset, src_, srcOffset, size, maxSpinCount); } - /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. + /// Push a TriggerData, a TriggerFlag, and a TriggerSync at the same time to the FIFO. /// @param offset The common offset into the destination and source memory regions. /// @param size The size of the transfer. 
MSCCLPP_DEVICE_INLINE void putWithSignalAndFlush(uint64_t offset, uint64_t size) { diff --git a/include/mscclpp/proxy.hpp b/include/mscclpp/proxy.hpp index b55f84e3..27a5e93a 100644 --- a/include/mscclpp/proxy.hpp +++ b/include/mscclpp/proxy.hpp @@ -11,27 +11,48 @@ namespace mscclpp { +/// Possible return values of a ProxyHandler. enum class ProxyHandlerResult { + /// Move to the next trigger in the FIFO. Continue, + /// Flush the FIFO and continue to the next trigger. FlushFifoTailAndContinue, + /// Stop the proxy and exit. Stop, }; class Proxy; + +/// Type of handler function for the proxy. using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>; +/// Host-side proxy for PortChannels. class Proxy { public: + /// Constructor of Proxy. + /// @param handler The handler function to be called for each trigger in the FIFO. + /// @param threadInit Optional function to be called in the proxy thread before starting the FIFO consumption. + /// @param fifoSize The size of the FIFO. Default is DEFAULT_FIFO_SIZE. Proxy(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize = DEFAULT_FIFO_SIZE); + + /// Constructor of Proxy. + /// @param handler The handler function to be called for each trigger in the FIFO. + /// @param fifoSize The size of the FIFO. Default is DEFAULT_FIFO_SIZE. Proxy(ProxyHandler handler, size_t fifoSize = DEFAULT_FIFO_SIZE); + + /// Destructor of Proxy. + /// This will stop the proxy if it is running. ~Proxy(); + /// Start the proxy. void start(); + + /// Stop the proxy. void stop(); /// This is a concurrent fifo which is multiple threads from the device /// can produce for and the sole proxy thread consumes it. - /// @return the fifo + /// @return A reference to the FIFO object used by the proxy. Fifo& fifo(); private: diff --git a/include/mscclpp/semaphore.hpp b/include/mscclpp/semaphore.hpp index 49e2687a..b3af7cd0 100644 --- a/include/mscclpp/semaphore.hpp +++ b/include/mscclpp/semaphore.hpp @@ -14,7 +14,7 @@ namespace mscclpp { /// A base class for semaphores. 
/// -/// An semaphore is a synchronization mechanism that allows the local peer to wait for the remote peer to complete a +/// A semaphore is a synchronization mechanism that allows the local peer to wait for the remote peer to complete a /// data transfer. The local peer signals the remote peer that it has completed a data transfer by incrementing the /// outbound semaphore ID. The incremented outbound semaphore ID is copied to the remote peer's inbound semaphore ID so /// that the remote peer can wait for the local peer to complete a data transfer. Vice versa, the remote peer signals @@ -22,9 +22,9 @@ namespace mscclpp { /// copying the incremented value to the local peer's inbound semaphore ID. /// /// @tparam InboundDeleter The deleter for inbound semaphore IDs. This is either `std::default_delete` for host memory -/// or @ref CudaDeleter for device memory. +/// or CudaDeleter for device memory. /// @tparam OutboundDeleter The deleter for outbound semaphore IDs. This is either `std::default_delete` for host memory -/// or @ref CudaDeleter for device memory. +/// or CudaDeleter for device memory. /// template