mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Construct ProxyChannel with shared pointers (#184)
This commit is contained in:
@@ -28,6 +28,9 @@ class Proxy {
|
||||
void start();
|
||||
void stop();
|
||||
|
||||
/// This is a concurrent fifo to which multiple threads on the device
|
||||
/// can produce, and which the sole proxy thread consumes.
|
||||
/// @return the fifo
|
||||
Fifo& fifo();
|
||||
|
||||
private:
|
||||
|
||||
@@ -62,7 +62,7 @@ class ProxyService : public BaseProxyService {
|
||||
private:
|
||||
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<RegisteredMemory> memories_;
|
||||
Proxy proxy_;
|
||||
std::shared_ptr<Proxy> proxy_;
|
||||
int deviceNumaNode;
|
||||
|
||||
void bindThread();
|
||||
@@ -75,16 +75,14 @@ struct ProxyChannel {
|
||||
private:
|
||||
SemaphoreId semaphoreId_;
|
||||
|
||||
Host2DeviceSemaphore::DeviceHandle semaphore_;
|
||||
std::shared_ptr<Host2DeviceSemaphore> semaphore_;
|
||||
|
||||
// This is a concurrent fifo to which multiple threads on the device
|
||||
// can produce, and which the sole proxy thread consumes.
|
||||
FifoDeviceHandle fifo_;
|
||||
std::shared_ptr<Proxy> proxy_;
|
||||
|
||||
public:
|
||||
ProxyChannel() = default;
|
||||
|
||||
ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore, FifoDeviceHandle fifo);
|
||||
ProxyChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore, std::shared_ptr<Proxy> proxy);
|
||||
|
||||
ProxyChannel(const ProxyChannel& other) = default;
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ void register_proxy_channel(nb::module_& m) {
|
||||
.def("proxy_channel", &ProxyService::proxyChannel, nb::arg("id"));
|
||||
|
||||
nb::class_<ProxyChannel>(m, "ProxyChannel")
|
||||
.def(nb::init<SemaphoreId, Host2DeviceSemaphore::DeviceHandle, FifoDeviceHandle>(), nb::arg("semaphoreId"),
|
||||
nb::arg("semaphore"), nb::arg("fifo"))
|
||||
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
|
||||
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
|
||||
.def("device_handle", &ProxyChannel::deviceHandle);
|
||||
|
||||
nb::class_<ProxyChannel::DeviceHandle>(m, "ProxyChannelDeviceHandle")
|
||||
|
||||
@@ -9,15 +9,16 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore,
|
||||
FifoDeviceHandle fifo)
|
||||
: semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {}
|
||||
/// Constructs a ProxyChannel that holds shared pointers to its semaphore and proxy.
/// @param semaphoreId value stored in semaphoreId_.
/// @param semaphore shared pointer copied into semaphore_.
/// @param proxy shared pointer copied into proxy_.
MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, std::shared_ptr<Host2DeviceSemaphore> semaphore,
|
||||
std::shared_ptr<Proxy> proxy)
|
||||
: semaphoreId_(semaphoreId), semaphore_(semaphore), proxy_(proxy) {}
|
||||
|
||||
/// Constructs a SimpleProxyChannel by storing a copy of the given ProxyChannel
/// together with the `dst` and `src` memory IDs.
/// @param proxyChan channel copied into proxyChan_.
/// @param dst memory ID stored in dst_.
/// @param src memory ID stored in src_.
MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src)
|
||||
: proxyChan_(proxyChan), dst_(dst), src_(src) {}
|
||||
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService()
|
||||
: proxy_([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
|
||||
: proxy_(std::make_shared<Proxy>([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
|
||||
[&]() { bindThread(); })) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
deviceNumaNode = getDeviceNumaNode(cudaDevice);
|
||||
@@ -44,12 +45,12 @@ MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(Se
|
||||
}
|
||||
|
||||
/// Builds a ProxyChannel for the semaphore stored at index `id` in semaphores_.
/// @param id semaphore ID, used directly as an index into semaphores_.
/// @return a ProxyChannel constructed from `id`, that semaphore, and this service's proxy.
/// NOTE(review): `id` is not range-checked here before indexing — confirm callers only
/// pass IDs that were previously registered with this service.
MSCCLPP_API_CPP ProxyChannel ProxyService::proxyChannel(SemaphoreId id) {
|
||||
return ProxyChannel(id, semaphores_[id]->deviceHandle(), proxy_.fifo().deviceHandle());
|
||||
return ProxyChannel(id, semaphores_[id], proxy_);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_.start(); }
|
||||
/// Starts the proxy by forwarding to Proxy::start() through the proxy_ pointer.
MSCCLPP_API_CPP void ProxyService::startProxy() { proxy_->start(); }
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_.stop(); }
|
||||
/// Stops the proxy by forwarding to Proxy::stop() through the proxy_ pointer.
MSCCLPP_API_CPP void ProxyService::stopProxy() { proxy_->stop(); }
|
||||
|
||||
MSCCLPP_API_CPP void ProxyService::bindThread() {
|
||||
if (deviceNumaNode >= 0) {
|
||||
@@ -84,7 +85,8 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
|
||||
}
|
||||
|
||||
/// Builds the DeviceHandle for this channel.
/// @return a DeviceHandle populated with semaphoreId_, the semaphore's device handle,
/// and the device handle of the proxy's fifo.
/// NOTE(review): semaphore_ and proxy_ are dereferenced without a null check; a
/// default-constructed ProxyChannel presumably holds null pointers — confirm this is
/// never called on one.
MSCCLPP_API_CPP ProxyChannel::DeviceHandle ProxyChannel::deviceHandle() const {
|
||||
return ProxyChannel::DeviceHandle{.semaphoreId_ = semaphoreId_, .semaphore_ = semaphore_, .fifo_ = fifo_};
|
||||
return ProxyChannel::DeviceHandle{
|
||||
.semaphoreId_ = semaphoreId_, .semaphore_ = semaphore_->deviceHandle(), .fifo_ = proxy_->fifo().deviceHandle()};
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP SimpleProxyChannel::DeviceHandle SimpleProxyChannel::deviceHandle() const {
|
||||
|
||||
@@ -281,8 +281,8 @@ __global__ void allgather4(int rank, int worldSize, int nRanksPerNode, size_t ne
|
||||
class AllGatherProxyService : public mscclpp::BaseProxyService {
|
||||
public:
|
||||
AllGatherProxyService(int worldSize, int rank, int cudaDevice);
|
||||
void startProxy() override { proxy_.start(); }
|
||||
void stopProxy() override { proxy_.stop(); }
|
||||
void startProxy() override { proxy_->start(); }
|
||||
void stopProxy() override { proxy_->stop(); }
|
||||
void setSendBytes(size_t sendBytes) { this->sendBytes_ = sendBytes; }
|
||||
void addRemoteMemory(mscclpp::RegisteredMemory memory) { remoteMemories_.push_back(memory); }
|
||||
void setLocalMemory(mscclpp::RegisteredMemory memory) { localMemory_ = memory; }
|
||||
@@ -294,8 +294,7 @@ class AllGatherProxyService : public mscclpp::BaseProxyService {
|
||||
std::vector<DeviceHandle<mscclpp::ProxyChannel>> proxyChannels() {
|
||||
std::vector<DeviceHandle<mscclpp::ProxyChannel>> result;
|
||||
for (auto& semaphore : semaphores_) {
|
||||
result.push_back(
|
||||
mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore->deviceHandle(), proxy_.fifo().deviceHandle())));
|
||||
result.push_back(mscclpp::deviceHandle(mscclpp::ProxyChannel(0, semaphore, proxy_)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -306,7 +305,7 @@ class AllGatherProxyService : public mscclpp::BaseProxyService {
|
||||
int cudaDevice_;
|
||||
size_t sendBytes_;
|
||||
|
||||
mscclpp::Proxy proxy_;
|
||||
std::shared_ptr<mscclpp::Proxy> proxy_;
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> semaphores_;
|
||||
std::vector<mscclpp::RegisteredMemory> remoteMemories_;
|
||||
mscclpp::RegisteredMemory localMemory_;
|
||||
@@ -319,11 +318,12 @@ AllGatherProxyService::AllGatherProxyService(int worldSize, int rank, int cudaDe
|
||||
sendBytes_(0),
|
||||
rank_(rank),
|
||||
cudaDevice_(cudaDevice),
|
||||
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
|
||||
[&]() {
|
||||
int deviceNumaNode = getDeviceNumaNode(cudaDevice_);
|
||||
numaBind(deviceNumaNode);
|
||||
}) {}
|
||||
proxy_(
|
||||
std::make_shared<mscclpp::Proxy>([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
|
||||
[&]() {
|
||||
int deviceNumaNode = getDeviceNumaNode(cudaDevice_);
|
||||
numaBind(deviceNumaNode);
|
||||
})) {}
|
||||
|
||||
mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
|
||||
size_t offset = rank_ * sendBytes_;
|
||||
|
||||
Reference in New Issue
Block a user