Construct ProxyChannel with shared pointers (#184)

This commit is contained in:
Changho Hwang
2023-09-18 13:46:23 +08:00
committed by GitHub
parent a6b24dcbed
commit 6c0ee72916
5 changed files with 29 additions and 26 deletions

View File

@@ -28,6 +28,9 @@ class Proxy {
void start();
void stop();
/// This is a concurrent fifo which is multiple threads from the device
/// can produce for and the sole proxy thread consumes it.
/// @return the fifo
Fifo& fifo();
private:

View File

@@ -62,7 +62,7 @@ class ProxyService : public BaseProxyService {
private:
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
Proxy proxy_;
std::shared_ptr<Proxy> proxy_;
int deviceNumaNode;
void bindThread();
@@ -75,16 +75,14 @@ struct ProxyChannel {
private:
SemaphoreId semaphoreId_;
Host2DeviceSemaphore::DeviceHandle semaphore_;
std::shared_ptr<Host2DeviceSemaphore> semaphore_;
// this is a concurrent fifo which is multiple threads from the device
// can produce for and the sole proxy thread consumes it.
FifoDeviceHandle fifo_;
std::shared_ptr<Proxy> proxy_;
public:
ProxyChannel() = default;
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, FifoDeviceHandle fifo);
ProxyChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy);
ProxyChannel(const ProxyChannel& other) = default;

View File

@@ -26,8 +26,8 @@ void register_proxy_channel(nb::module_& m) {
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
nb::class_<ProxyChannel>(m, "ProxyChannel")
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, FifoDeviceHandle>(), nb::arg("semaphoreId"),
nb::arg("semaphore"), nb::arg("fifo"))
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
.def("device_handle", &ProxyChannel::deviceHandle);
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")

View File

@@ -9,15 +9,16 @@
namespace mscclpp {
MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore,
FifoDeviceHandle fifo)
: semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {}
MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
std::shared_ptr<Proxy> proxy)
: semaphoreId_(semaphoreId), semaphore_(semaphore), proxy_(proxy) {}
MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src)
: proxyChan_(proxyChan), dst_(dst), src_(src) {}
MSCCLPP_API_CPP ProxyService::ProxyService()
: proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
: proxy_(std::make_shared<Proxy>([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
[&]() { bindThread(); })) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
deviceNumaNode = getDeviceNumaNode(cudaDevice);
@@ -44,12 +45,12 @@ MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(Se
}
MSCCLPP_API_CPP ProxyChannel ProxyService::proxyChannel(SemaphoreId id) {
return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceHandle());
return ProxyChannel(id, semaphores_[id], proxy_);
}
MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_.start(); }
MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_->start(); }
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_.stop(); }
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
MSCCLPP_API_CPP void ProxyService::bindThread() {
if (deviceNumaNode >= 0) {
@@ -84,7 +85,8 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
}
MSCCLPP_API_CPP ProxyChannel::DeviceHandle ProxyChannel::deviceHandle() const {
return ProxyChannel::DeviceHandle{.semaphoreId_ = semaphoreId_, .semaphore_ = semaphore_, .fifo_ = fifo_};
return ProxyChannel::DeviceHandle{
.semaphoreId_ = semaphoreId_, .semaphore_ = semaphore_->deviceHandle(), .fifo_ = proxy_->fifo().deviceHandle()};
}
MSCCLPP_API_CPP SimpleProxyChannel::DeviceHandle SimpleProxyChannel::deviceHandle() const {

View File

@@ -281,8 +281,8 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
class AllGatherProxyService : public mscclpp::BaseProxyService {
public:
AllGatherProxyService(int worldSize, int rank, int cudaDevice);
void startProxy() override { proxy_.start(); }
void stopProxy() override { proxy_.stop(); }
void startProxy() override { proxy_->start(); }
void stopProxy() override { proxy_->stop(); }
void setSendBytes(size_t sendBytes) { this->sendBytes_ = sendBytes; }
void addRemoteMemory(mscclpp::RegisteredMemory memory) { remoteMemories_.push_back(memory); }
void setLocalMemory(mscclpp::RegisteredMemory memory) { localMemory_ = memory; }
@@ -294,8 +294,7 @@ class AllGatherProxyService : public mscclpp::BaseProxyService {
std::vector<DeviceHandle<mscclpp::ProxyChannel>> proxyChannels() {
std::vector<DeviceHandle<mscclpp::ProxyChannel>> result;
for (auto& semaphore : semaphores_) {
result.push_back(
mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), proxy_.fifo().deviceHandle())));
result.push_back(mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore, proxy_)));
}
return result;
}
@@ -306,7 +305,7 @@ class AllGatherProxyService : public mscclpp::BaseProxyService {
int cudaDevice_;
size_t sendBytes_;
mscclpp::Proxy proxy_;
std::shared_ptr<mscclpp::Proxy> proxy_;
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
std::vector<mscclpp::RegisteredMemory> remoteMemories_;
mscclpp::RegisteredMemory localMemory_;
@@ -319,11 +318,12 @@ AllGatherProxyService::AllGatherProxyService(int worldSize, int rank, int cudaDe
sendBytes_(0),
rank_(rank),
cudaDevice_(cudaDevice),
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
[&]() {
int deviceNumaNode = getDeviceNumaNode(cudaDevice_);
numaBind(deviceNumaNode);
}) {}
proxy_(
std::make_shared<mscclpp::Proxy>([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
[&]() {
int deviceNumaNode = getDeviceNumaNode(cudaDevice_);
numaBind(deviceNumaNode);
})) {}
mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
size_t offset = rank_ * sendBytes_;