Files
mscclpp/src/proxy_channel.cc
Saeed Maleki 8d1b984bed Change device handle interfaces & others (#142)
* Changed device handle interfaces
* Changed proxy service interfaces
* Move device code into separate files
* Fixed FIFO polling issues
* Add configuration arguments in several interface functions

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
Co-authored-by: Binyang Li <binyli@microsoft.com>
Co-authored-by: root <root@a100-saemal0.qxveptpukjsuthqvv514inp03c.gx.internal.cloudapp.net>
2023-08-16 20:00:56 +08:00

95 lines
3.4 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <mscclpp/numa.hpp>
#include <mscclpp/proxy_channel.hpp>
#include <utility>

#include "api.h"
#include "debug.h"
namespace mscclpp {
MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, Host2DeviceSemaphore::DeviceHandle semaphore,
FifoDeviceHandle fifo)
: semaphoreId_(semaphoreId), semaphore_(semaphore), fifo_(fifo) {}
MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src)
: proxyChan_(proxyChan), dst_(dst), src_(src) {}
// Constructs the service: hands the proxy two callbacks (one forwarding each trigger to
// handleTrigger(), one calling bindThread()), then records the NUMA node of the current
// CUDA device in deviceNumaNode.
MSCCLPP_API_CPP ProxyService::ProxyService()
    : proxy_([this](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
             [this]() { bindThread(); }) {
  // MSCCLPP_CUDATHROW is expected to raise on a CUDA error, so `currentDevice` is only
  // consumed after cudaGetDevice succeeded.
  int currentDevice = 0;
  MSCCLPP_CUDATHROW(cudaGetDevice(&currentDevice));
  deviceNumaNode = getDeviceNumaNode(currentDevice);
}
// Creates a new Host2DeviceSemaphore from `communicator` and `connection`, appends it to the
// service's semaphore list, and returns its ID (the index of the new entry in that list).
MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator,
                                                               std::shared_ptr<Connection> connection) {
  auto created = std::make_shared<Host2DeviceSemaphore>(communicator, connection);
  semaphores_.push_back(created);
  // The ID is the position of the entry that was just appended.
  return semaphores_.size() - 1;
}
// Registers an already-constructed semaphore with the service.
//
// @param semaphore Shared pointer to the semaphore to register; the service keeps a reference.
// @return The ID of the registered semaphore (its index in the service's semaphore list).
MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Host2DeviceSemaphore> semaphore) {
  // `semaphore` is a by-value sink parameter: move it into the list instead of copying to
  // avoid an extra atomic reference-count increment/decrement.
  semaphores_.push_back(std::move(semaphore));
  return semaphores_.size() - 1;
}
// Registers a memory region with the service.
//
// @param memory The registered memory to store.
// @return The ID of the stored memory (its index in the service's memory list).
MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {
  // `memory` is a by-value sink parameter: move it into the list rather than copying it again.
  // If RegisteredMemory has no move constructor this falls back to a copy, as before.
  memories_.push_back(std::move(memory));
  return memories_.size() - 1;
}
// Returns the semaphore registered under `id`. `id` is used directly as an index into the
// semaphore list and is not range-checked here.
MSCCLPP_API_CPP std::shared_ptr<Host2DeviceSemaphore> ProxyService::semaphore(SemaphoreId id) const {
  const auto& registered = semaphores_[id];
  return registered;
}
// Builds a ProxyChannel for the semaphore registered under `id`, pairing that semaphore's
// device handle with the device handle of the proxy's FIFO. `id` is not range-checked here.
MSCCLPP_API_CPP ProxyChannel ProxyService::proxyChannel(SemaphoreId id) {
  auto semaphoreHandle = semaphores_[id]->deviceHandle();
  auto fifoHandle = proxy_.fifo().deviceHandle();
  return ProxyChannel(id, semaphoreHandle, fifoHandle);
}
// Starts the underlying proxy by delegating to proxy_.start().
MSCCLPP_API_CPP void ProxyService::startProxy() {
  proxy_.start();
}
// Stops the underlying proxy by delegating to proxy_.stop().
MSCCLPP_API_CPP void ProxyService::stopProxy() {
  proxy_.stop();
}
// Binds to the NUMA node recorded in deviceNumaNode via numaBind() and logs the node.
// Does nothing when deviceNumaNode is negative.
// NOTE(review): this is passed to the proxy as a callback in the constructor; presumably the
// proxy thread invokes it, so the binding applies to that thread — confirm against Proxy.
MSCCLPP_API_CPP void ProxyService::bindThread() {
  if (deviceNumaNode < 0) {
    return;
  }
  numaBind(deviceNumaNode);
  INFO(MSCCLPP_INIT, "NUMA node of ProxyService proxy thread is set to %d", deviceNumaNode);
}
// Handles one trigger received by the proxy.
//
// The raw trigger is viewed as a ChannelTrigger and its `type` bits select, in this order:
//   - TriggerData: write `size` bytes from the source memory (at srcOffset) to the destination
//     memory (at dstOffset) through the connection of the semaphore selected by `chanId`;
//   - TriggerFlag: signal that semaphore;
//   - TriggerSync: flush that semaphore's connection.
//
// Returns FlushFifoTailAndContinue when TriggerSync was set, otherwise Continue.
// The IDs carried by the trigger are used as indices without range checks.
ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
  // View the raw trigger through the ChannelTrigger layout to reach its fields.
  ChannelTrigger* chanTrigger = reinterpret_cast<ChannelTrigger*>(&triggerRaw);
  const auto& fields = chanTrigger->fields;

  std::shared_ptr<Host2DeviceSemaphore> target = semaphores_[fields.chanId];

  if (fields.type & TriggerData) {
    RegisteredMemory& dstMem = memories_[fields.dstMemoryId];
    RegisteredMemory& srcMem = memories_[fields.srcMemoryId];
    target->connection()->write(dstMem, fields.dstOffset, srcMem, fields.srcOffset, fields.size);
  }

  if (fields.type & TriggerFlag) {
    target->signal();
  }

  if (fields.type & TriggerSync) {
    target->connection()->flush();
    return ProxyHandlerResult::FlushFifoTailAndContinue;
  }

  return ProxyHandlerResult::Continue;
}
// Returns a DeviceHandle populated with this channel's semaphore ID, semaphore device handle,
// and FIFO device handle.
MSCCLPP_API_CPP ProxyChannel::DeviceHandle ProxyChannel::deviceHandle() const {
  DeviceHandle handle{
      .semaphoreId_ = semaphoreId_,
      .semaphore_ = semaphore_,
      .fifo_ = fifo_,
  };
  return handle;
}
// Returns a DeviceHandle populated with the wrapped proxy channel's device handle and this
// channel's destination and source memory IDs.
MSCCLPP_API_CPP SimpleProxyChannel::DeviceHandle SimpleProxyChannel::deviceHandle() const {
  DeviceHandle handle{
      .proxyChan_ = proxyChan_.deviceHandle(),
      .dst_ = dst_,
      .src_ = src_,
  };
  return handle;
}
} // namespace mscclpp