diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index 831d04ba..67479224 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -29,20 +29,23 @@ class ProxyService : public BaseProxyService { ProxyService(); /// Build and add a semaphore to the proxy service. + /// @param communicator The communicator for bootstrapping. /// @param connection The connection associated with the semaphore. /// @return The ID of the semaphore. SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr connection); + /// Build and add a semaphore with pitch to the proxy service. This is used for 2D transfers. + /// @param communicator The communicator for bootstrapping. + /// @param connection The connection associated with the semaphore. + /// @param pitch The pitch pair. + SemaphoreId buildAndAddSemaphore(Communicator& communicator, std::shared_ptr connection, + std::pair pitch); + /// Add a semaphore to the proxy service. /// @param semaphore The semaphore to be added /// @return The ID of the semaphore. SemaphoreId addSemaphore(std::shared_ptr semaphore); - /// Add a 2D channel to the proxy service. - /// @param connection The connection associated with the channel. - /// @param pitch The pitch pair. - SemaphoreId add2DChannel(std::shared_ptr connection, std::pair pitch); - /// Register a memory region with the proxy service. /// @param memory The memory region to register. /// @return The ID of the memory region. diff --git a/include/mscclpp/proxy_channel_device.hpp b/include/mscclpp/proxy_channel_device.hpp index db90eac7..23b696a7 100644 --- a/include/mscclpp/proxy_channel_device.hpp +++ b/include/mscclpp/proxy_channel_device.hpp @@ -27,6 +27,10 @@ const TriggerType TriggerSync = 0x4; // Trigger a flush.
#define MSCCLPP_BITS_CONNID 10 #define MSCCLPP_BITS_FIFO_RESERVED 1 +#define MSCCLPP_BITS_WIDTH_SIZE 16 +#define MSCCLPP_BITS_HEIGHT_SIZE 16 +#define MSCCLPP_2D_FLAG 1 + /// Basic structure of each work element in the FIFO. union ChannelTrigger { ProxyTrigger value; @@ -47,6 +51,25 @@ union ChannelTrigger { uint64_t reserved : MSCCLPP_BITS_FIFO_RESERVED; } fields; + struct { + // First 64 bits: value[0] + uint64_t width : MSCCLPP_BITS_WIDTH_SIZE; + uint64_t height : MSCCLPP_BITS_HEIGHT_SIZE; + uint64_t srcOffset : MSCCLPP_BITS_OFFSET; + uint64_t + : (64 - MSCCLPP_BITS_WIDTH_SIZE - MSCCLPP_BITS_HEIGHT_SIZE - MSCCLPP_BITS_OFFSET); // ensure 64-bit alignment + // Second 64 bits: value[1] + uint64_t dstOffset : MSCCLPP_BITS_OFFSET; + uint64_t srcMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t dstMemoryId : MSCCLPP_BITS_REGMEM_HANDLE; + uint64_t type : MSCCLPP_BITS_TYPE; + uint64_t chanId : MSCCLPP_BITS_CONNID; + uint64_t multiDimensionFlag : MSCCLPP_2D_FLAG; + uint64_t : (64 - MSCCLPP_BITS_OFFSET - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_REGMEM_HANDLE - MSCCLPP_BITS_TYPE - + MSCCLPP_BITS_CONNID - MSCCLPP_2D_FLAG - MSCCLPP_BITS_FIFO_RESERVED); // ensure 64-bit alignment + uint64_t reserved : MSCCLPP_BITS_FIFO_RESERVED; + } fields2D; + #ifdef __CUDACC__ /// Default constructor. __forceinline__ __device__ ChannelTrigger() {} @@ -71,6 +94,27 @@ union ChannelTrigger { << MSCCLPP_BITS_OFFSET) + dstOffset); } + + /// Constructor for a 2D trigger (sets the multi-dimension flag bit). + /// @param type The type of the trigger. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. Must fit in MSCCLPP_BITS_WIDTH_SIZE bits; it is not masked. + /// @param height The height of the 2D region. Must fit in MSCCLPP_BITS_HEIGHT_SIZE bits; it is not masked. + /// @param semaphoreId The ID of the semaphore.
+ __device__ ChannelTrigger(TriggerType type, MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint64_t width, uint64_t height, int semaphoreId) { + value.fst = (((srcOffset << MSCCLPP_BITS_HEIGHT_SIZE) + height) << MSCCLPP_BITS_WIDTH_SIZE) + width; + value.snd = ((((((((((1ULL << MSCCLPP_BITS_CONNID) + semaphoreId) << MSCCLPP_BITS_TYPE) + type) + << MSCCLPP_BITS_REGMEM_HANDLE) + + dst) + << MSCCLPP_BITS_REGMEM_HANDLE) + + src) + << MSCCLPP_BITS_OFFSET) + + dstOffset); + } #endif // __CUDACC__ }; @@ -104,6 +148,28 @@ struct ProxyChannelDeviceHandle { put(dst, offset, src, offset, size); } + /// Push a @ref TriggerData describing a 2D region to the FIFO. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region (must fit in MSCCLPP_BITS_WIDTH_SIZE bits). + /// @param height The height of the 2D region (must fit in MSCCLPP_BITS_HEIGHT_SIZE bits). + __forceinline__ __device__ void put2D(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint32_t width, uint32_t height) { + fifo_.push(ChannelTrigger(TriggerData, dst, dstOffset, src, srcOffset, width, height, semaphoreId_).value); + } + + /// Push a @ref TriggerData describing a 2D region to the FIFO. + /// @param dst The destination memory region. + /// @param src The source memory region. + /// @param offset The common offset into the destination and source memory regions. + /// @param width The width of the 2D region (must fit in MSCCLPP_BITS_WIDTH_SIZE bits). + /// @param height The height of the 2D region (must fit in MSCCLPP_BITS_HEIGHT_SIZE bits). + __forceinline__ __device__ void put2D(MemoryId dst, MemoryId src, uint64_t offset, uint32_t width, uint32_t height) { + put2D(dst, offset, src, offset, width, height); + } + /// Push a @ref TriggerFlag to the FIFO.
__forceinline__ __device__ void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); @@ -120,6 +186,19 @@ struct ProxyChannelDeviceHandle { fifo_.push(ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, size, semaphoreId_).value); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param dstOffset The offset into the destination memory region. + /// @param src The source memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(MemoryId dst, uint64_t dstOffset, MemoryId src, uint64_t srcOffset, + uint32_t width, uint32_t height) { + fifo_.push( + ChannelTrigger(TriggerData | TriggerFlag, dst, dstOffset, src, srcOffset, width, height, semaphoreId_).value); + } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. /// @param dst The destination memory region. /// @param src The source memory region. @@ -129,6 +208,17 @@ struct ProxyChannelDeviceHandle { putWithSignal(dst, offset, src, offset, size); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dst The destination memory region. + /// @param src The source memory region. + /// @param offset The common offset into the destination and source memory regions. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(MemoryId dst, MemoryId src, uint64_t offset, uint32_t width, + uint32_t height) { + put2DWithSignal(dst, offset, src, offset, width, height); + } + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. /// @param dst The destination memory region. 
/// @param dstOffset The offset into the destination memory region. @@ -178,6 +268,15 @@ struct SimpleProxyChannelDeviceHandle { proxyChan_.put(dst_, dstOffset, src_, srcOffset, size); } + /// Push a @ref TriggerData to the FIFO. + /// @param dstOffset The offset into the destination memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2D(uint64_t dstOffset, uint64_t srcOffset, uint32_t width, uint32_t height) { + proxyChan_.put2D(dst_, dstOffset, src_, srcOffset, width, height); + } + /// Push a @ref TriggerData to the FIFO. /// @param offset The common offset into the destination and source memory regions. /// @param size The size of the transfer. @@ -194,11 +293,29 @@ struct SimpleProxyChannelDeviceHandle { proxyChan_.putWithSignal(dst_, dstOffset, src_, srcOffset, size); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param dstOffset The offset into the destination memory region. + /// @param srcOffset The offset into the source memory region. + /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(uint64_t dstOffset, uint64_t srcOffset, uint32_t width, + uint32_t height) { + proxyChan_.put2DWithSignal(dst_, dstOffset, src_, srcOffset, width, height); + } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. /// @param offset The common offset into the destination and source memory regions. /// @param size The size of the transfer. __forceinline__ __device__ void putWithSignal(uint64_t offset, uint64_t size) { putWithSignal(offset, offset, size); } + /// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO. + /// @param offset The common offset into the destination and source memory regions.
+ /// @param width The width of the 2D region. + /// @param height The height of the 2D region. + __forceinline__ __device__ void put2DWithSignal(uint64_t offset, uint32_t width, uint32_t height) { + put2DWithSignal(offset, offset, width, height); + } + /// Push a @ref TriggerData, a @ref TriggerFlag, and a @ref TriggerSync at the same time to the FIFO. /// @param dstOffset The offset into the destination memory region. /// @param srcOffset The offset into the source memory region. diff --git a/python/proxy_channel_py.cpp b/python/proxy_channel_py.cpp index a483f99d..5e264acf 100644 --- a/python/proxy_channel_py.cpp +++ b/python/proxy_channel_py.cpp @@ -19,7 +19,9 @@ void register_proxy_channel(nb::module_& m) { .def(nb::init<>()) .def("start_proxy", &ProxyService::startProxy) .def("stop_proxy", &ProxyService::stopProxy) - .def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection")) + .def("build_and_add_semaphore", + nb::overload_cast>(&ProxyService::buildAndAddSemaphore), + nb::arg("comm"), nb::arg("connection")) .def("add_semaphore", &ProxyService::addSemaphore, nb::arg("semaphore")) .def("add_memory", &ProxyService::addMemory, nb::arg("memory")) .def("semaphore", &ProxyService::semaphore, nb::arg("id")) diff --git a/src/connection.cc b/src/connection.cc index 6bf9e5bc..6820ad33 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -3,7 +3,6 @@ #include "connection.hpp" -#include #include #include "debug.h" diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index 634ab8f2..cfe6862b 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -29,20 +29,21 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& com return semaphores_.size() - 1; } -MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr semaphore) { - semaphores_.push_back(semaphore); - return semaphores_.size() - 1; -} - -MSCCLPP_API_CPP SemaphoreId ProxyService::add2DChannel(std::shared_ptr 
connection, - std::pair pitch) { - semaphores_.push_back(std::make_shared(communicator_, connection)); +MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator, + std::shared_ptr connection, + std::pair pitch) { + semaphores_.push_back(std::make_shared(communicator, connection)); SemaphoreId id = semaphores_.size() - 1; if (id >= pitches_.size()) pitches_.resize(id + 1, std::pair(0, 0)); pitches_[id] = pitch; return id; } +MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr semaphore) { + semaphores_.push_back(semaphore); + return semaphores_.size() - 1; +} + MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) { memories_.push_back(memory); return memories_.size() - 1; diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu index 4ce1ccef..20a1069e 100644 --- a/test/mp_unit/proxy_channel_tests.cu +++ b/test/mp_unit/proxy_channel_tests.cu @@ -17,6 +17,12 @@ void ProxyChannelOneToOneTest::TearDown() { CommunicatorTestBase::TearDown(); } void ProxyChannelOneToOneTest::setupMeshConnections(std::vector& proxyChannels, bool useIbOnly, void* sendBuff, size_t sendBuffBytes, void* recvBuff, size_t recvBuffBytes) { + setupMeshConnections(proxyChannels, useIbOnly, sendBuff, sendBuffBytes, sendBuffBytes, recvBuff, recvBuffBytes); +} + +void ProxyChannelOneToOneTest::setupMeshConnections(std::vector& proxyChannels, + bool useIbOnly, void* sendBuff, size_t sendBuffBytes, size_t pitch, + void* recvBuff, size_t recvBuffBytes) { const int rank = communicator->bootstrap()->getRank(); const int worldSize = communicator->bootstrap()->getNranks(); const bool isInPlace = (recvBuff == nullptr); @@ -49,7 +55,12 @@ void ProxyChannelOneToOneTest::setupMeshConnections(std::vectorsetup(); - mscclpp::SemaphoreId cid = proxyService->buildAndAddSemaphore(*communicator, conn); + mscclpp::SemaphoreId cid; + if (sendBuffBytes == pitch) { + cid = proxyService->buildAndAddSemaphore(*communicator, 
conn); + } else { + cid = proxyService->buildAndAddSemaphore(*communicator, conn, std::pair(pitch, pitch)); + } communicator->setup(); proxyChannels.emplace_back(proxyService->proxyChannel(cid), proxyService->addMemory(remoteMemory.get()), @@ -230,7 +241,7 @@ TEST_F(ProxyChannelOneToOneTest, PingPongTile) { const int nElem = 4 * 1024 * 1024; - std::vector> proxyChannels; + std::vector proxyChannels; std::shared_ptr buff = mscclpp::allocSharedCuda(nElem); const int pitchSize = 512; // the buff tile is 8192x128 setupMeshConnections(proxyChannels, false, buff.get(), nElem * sizeof(int), pitchSize); @@ -239,7 +250,7 @@ TEST_F(ProxyChannelOneToOneTest, PingPongTile) { MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstProxyChans, proxyChannels.data(), sizeof(DeviceHandle))); - channelService->startProxy(); + proxyService->startProxy(); std::shared_ptr ret = mscclpp::makeSharedCudaHost(0);