Channels work

This commit is contained in:
Olli Saarikivi
2023-05-03 17:11:25 +00:00
parent 6002a520b6
commit 81e7d1b344
7 changed files with 101 additions and 10 deletions

View File

@@ -121,7 +121,7 @@ LDFLAGS := $(NVLDFLAGS) $(GDRCOPY_LDFLAGS) -libverbs -lnuma
LIBSRCS := $(addprefix src/,debug.cc utils.cc init.cc proxy.cc ib.cc config.cc)
LIBSRCS += $(addprefix src/bootstrap/,bootstrap.cc socket.cc)
LIBSRCS += $(addprefix src/,communicator.cc connection.cc registered_memory.cc)
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc)
LIBSRCS += $(addprefix src/,epoch.cc proxy_cpp.cc fifo.cc channel.cc)
ifneq ($(NPKIT), 0)
LIBSRCS += $(addprefix src/misc/,npkit.cc)
endif

26
src/channel.cc Normal file
View File

@@ -0,0 +1,26 @@
#include "channel.hpp"
#include "utils.h"
#include "checks.hpp"
#include "api.h"
#include "debug.h"
namespace mscclpp {
namespace channel {
MSCCLPP_API_CPP DeviceChannelService::DeviceChannelService(Communicator& communicator) : communicator_(communicator),
proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
int cudaDevice;
CUDATHROW(cudaGetDevice(&cudaDevice));
MSCCLPPTHROW(getDeviceNumaNode(cudaDevice, &deviceNumaNode));
}
MSCCLPP_API_CPP void DeviceChannelService::bindThread()
{
if (deviceNumaNode >= 0) {
MSCCLPPTHROW(numaBind(deviceNumaNode));
INFO(MSCCLPP_INIT, "NUMA node of DeviceChannelService proxy thread is set to %d", deviceNumaNode);
}
}
} // namespace channel
} // namespace mscclpp

View File

@@ -4,6 +4,7 @@
#include "infiniband/verbs.h"
#include "npkit/npkit.h"
#include "registered_memory.hpp"
#include "utils.hpp"
namespace mscclpp {
@@ -33,7 +34,7 @@ int ConnectionBase::tag() { return tag_; }
CudaIpcConnection::CudaIpcConnection(int remoteRank, int tag) : ConnectionBase(remoteRank, tag)
{
cudaStreamCreate(&stream);
CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
}
CudaIpcConnection::~CudaIpcConnection()
@@ -54,6 +55,7 @@ Transport CudaIpcConnection::remoteTransport()
void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size)
{
ScopedTimer timer("CudaIpcConnection::write");
validateTransport(dst, remoteTransport());
validateTransport(src, transport());

View File

@@ -5,6 +5,7 @@
#include "mscclpp.hpp"
#include "proxy.hpp"
#include "mscclppfifo.hpp"
#include "utils.hpp"
namespace mscclpp {
namespace channel {
@@ -177,7 +178,7 @@ inline ProxyHandler makeChannelProxyHandler(DeviceChannelService& channelService
class DeviceChannelService {
public:
DeviceChannelService(Communicator& communicator) : communicator_(communicator), proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {}
DeviceChannelService(Communicator& communicator);
ChannelId addChannel(std::shared_ptr<Connection> connection) {
channels_.push_back(Channel(communicator_, connection));
@@ -200,6 +201,9 @@ private:
std::vector<Channel> channels_;
std::vector<RegisteredMemory> memories_;
Proxy proxy_;
int deviceNumaNode;
void bindThread();
ProxyHandlerResult handleTrigger(ProxyTrigger triggerRaw) {
ChannelTrigger* trigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);

View File

@@ -21,12 +21,11 @@ using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
class Proxy
{
public:
Proxy(ProxyHandler handler, std::function<void()> threadInit);
Proxy(ProxyHandler handler);
~Proxy();
void start();
void stop();
HostProxyFifo& fifo();

54
src/include/utils.hpp Normal file
View File

@@ -0,0 +1,54 @@
#ifndef MSCCLPP_UTILS_HPP_
#define MSCCLPP_UTILS_HPP_
#include <chrono>
#include <stdio.h>
namespace mscclpp {
struct Timer
{
std::chrono::steady_clock::time_point start;
Timer()
{
start = std::chrono::steady_clock::now();
}
int64_t elapsed()
{
auto end = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
}
void reset()
{
start = std::chrono::steady_clock::now();
}
void print(const char* name)
{
auto end = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
printf("%s: %ld us\n", name, elapsed);
}
};
struct ScopedTimer
{
Timer timer;
const char* name;
ScopedTimer(const char* name) : name(name)
{
}
~ScopedTimer()
{
timer.print(name);
}
};
} // namespace mscclpp
#endif // MSCCLPP_UTILS_HPP_

View File

@@ -2,6 +2,7 @@
#include "api.h"
#include "mscclpp.hpp"
#include "utils.h"
#include "utils.hpp"
#include <atomic>
#include <thread>
@@ -14,18 +15,23 @@ const int ProxyFlushPeriod = 4;
struct Proxy::Impl
{
ProxyHandler handler;
std::function<void()> threadInit;
HostProxyFifo fifo;
std::thread service;
std::atomic_bool running;
Impl(ProxyHandler handler) : handler(handler), running(false)
Impl(ProxyHandler handler, std::function<void()> threadInit) : handler(handler), threadInit(threadInit), running(false)
{
}
};
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler)
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit)
{
pimpl = std::make_unique<Impl>(handler, threadInit);
}
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {})
{
pimpl = std::make_unique<Impl>(handler);
}
MSCCLPP_API_CPP Proxy::~Proxy()
@@ -39,8 +45,8 @@ MSCCLPP_API_CPP void Proxy::start()
{
pimpl->running = true;
pimpl->service = std::thread([this] {
// from this point on, proxy thread will stay close to the device
// PROXYMSCCLPPCHECK(numaBind(pimpl->comm->devNumaNode)); // TODO: reenable this
pimpl->threadInit();
ProxyHandler handler = this->pimpl->handler;
HostProxyFifo& fifo = this->pimpl->fifo;