diff --git a/Makefile b/Makefile index 78b993cf..2b80afb5 100644 --- a/Makefile +++ b/Makefile @@ -121,7 +121,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc) LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc) LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc) -LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc) +LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc) ifneq ($(NPKIT), 0) LIBSRCS += $(addprefix src/misc/,npkit.cc) endif diff --git a/src/channel.cc b/src/channel.cc new file mode 100644 index 00000000..42572390 --- /dev/null +++ b/src/channel.cc @@ -0,0 +1,26 @@ +#include "channel.hpp" +#include "utils.h" +#include "checks.hpp" +#include "api.h" +#include "debug.h" + +namespace mscclpp { +namespace channel { + +MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) : communicator_(communicator), + proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) { + int cudaDevice; + CUDATHROW(cudaGetDevice(&cudaDevice)); + MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode)); +} + +MSCCLPP_API_CPP void DeviceChannelService::bindThread() +{ + if (deviceNumaNode >= 0) { + MSCCLPPTHROW(numaBind(deviceNumaNode)); + INFO(MSCCLPP_INIT, "NUMA node of DeviceChannelService proxy thread is set to %d", deviceNumaNode); + } +} + +} // namespace channel +} // namespace mscclpp \ No newline at end of file diff --git a/src/connection.cc b/src/connection.cc index 66c54f06..0dee770b 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -4,6 +4,7 @@ #include "infiniband/verbs.h" #include "npkit/npkit.h" #include "registered_memory.hpp" +#include "utils.hpp" namespace mscclpp { @@ -33,7 +34,7 @@ int ConnectionBase::tag() { return tag_; } CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag) { - cudaStreamCreate(&stream); + CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); } CudaIpcConnection::~CudaIpcConnection() @@ -54,6 +55,7 @@ Transport CudaIpcConnection::remoteTransport() void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t size) { + ScopedTimer timer("CudaIpcConnection::write"); validateTransport(dst, remoteTransport()); validateTransport(src, transport()); diff --git a/src/include/channel.hpp b/src/include/channel.hpp index 42826f4f..eb4bd9e7 100644 --- a/src/include/channel.hpp +++ b/src/include/channel.hpp @@ -5,6 +5,7 @@ #include "mscclpp.hpp" #include "proxy.hpp" #include "mscclppfifo.hpp" +#include "utils.hpp" namespace mscclpp { namespace channel { @@ -177,7 +178,7 @@ inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService class DeviceChannelService { public: - DeviceChannelService(Communicator& communicator) : communicator_(communicator), proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {} + DeviceChannelService(Communicator& communicator); ChannelId addChannel(std::shared_ptr connection) { channels_.push_back(Channel(communicator_, connection)); @@ -200,6 +201,9 @@ private: std::vector channels_; std::vector memories_; Proxy proxy_; + int deviceNumaNode; + + void bindThread(); ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) { ChannelTrigger* trigger = reinterpret_cast(&triggerRaw); diff --git a/src/include/proxy.hpp b/src/include/proxy.hpp index f913beac..51ae4752 100644 --- a/src/include/proxy.hpp +++ b/src/include/proxy.hpp @@ -21,12 +21,11 @@ using ProxyHandler = std::function; class Proxy { public: + Proxy(ProxyHandler handler, std::function threadInit); Proxy(ProxyHandler handler); - ~Proxy(); void start(); - void stop(); HostProxyFifo& fifo(); diff --git a/src/include/utils.hpp b/src/include/utils.hpp new file mode 100644 index 00000000..9abf9994 --- /dev/null +++ b/src/include/utils.hpp @@ -0,0 +1,54 @@ +#ifndef MSCCLPP_UTILS_HPP_ +#define MSCCLPP_UTILS_HPP_ + +#include +#include + +namespace mscclpp { + +struct Timer +{ + std::chrono::steady_clock::time_point start; + + Timer() + { + start = std::chrono::steady_clock::now(); + } + + int64_t elapsed() + { + auto end = std::chrono::steady_clock::now(); + return std::chrono::duration_cast(end - start).count(); + } + + void reset() + { + start = std::chrono::steady_clock::now(); + } + + void print(const char* name) + { + auto end = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start).count(); + printf("%s: %ld us\n", name, elapsed); + } +}; + +struct ScopedTimer +{ + Timer timer; + const char* name; + + ScopedTimer(const char* name) : name(name) + { + } + + ~ScopedTimer() + { + timer.print(name); + } +}; + +} // namespace mscclpp + +#endif // MSCCLPP_UTILS_HPP_ diff --git a/src/proxy_cpp.cc b/src/proxy_cpp.cc index 2fb8c2b0..b1626813 100644 --- a/src/proxy_cpp.cc +++ b/src/proxy_cpp.cc @@ -2,6 +2,7 @@ #include "api.h" #include "mscclpp.hpp" #include "utils.h" +#include "utils.hpp" #include #include @@ -14,18 +15,23 @@ const int ProxyFlushPeriod = 4; struct Proxy::Impl { ProxyHandler handler; + std::function threadInit; HostProxyFifo fifo; std::thread service; std::atomic_bool running; - Impl(ProxyHandler handler) : handler(handler), running(false) + Impl(ProxyHandler handler, std::function threadInit) : handler(handler), threadInit(threadInit), running(false) { } }; -MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function threadInit) +{ + pimpl = std::make_unique(handler, threadInit); +} + +MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {}) { - pimpl = std::make_unique(handler); } MSCCLPP_API_CPP Proxy::~Proxy() @@ -39,8 +45,8 @@ MSCCLPP_API_CPP void Proxy::start() { pimpl->running = true; pimpl->service = std::thread([this] { - // from this point on, proxy thread will stay close to the device - // PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this + + pimpl->threadInit(); ProxyHandler handler = this->pimpl->handler; HostProxyFifo& fifo = this->pimpl->fifo;