mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
FIFO improvements (#557)
* Revert `MSCCLPP_FIFO_USE_TAIL_REPLICA=1` back to the default. * Optimize `FifoDeviceHandle`. * Do not use `cudaHostAllocWriteCombined` that increases latency. * Pin host memory for `Host2DeviceSemaphore::outboundSemaphore_`. * Fix proxy NUMA binding issues. * Prevent graph capture inside proxy threads. * Now `CudaIpcConnection` skips stream sync when unnecessary. * Now any type of connection needs to hold a shared pointer to the context for memory safety. * Now a context should be always managed by a shared pointer for memory safety. * Minor docs & interface improvements. * Minor fix in `mscclpp-test` correctness test.
This commit is contained in:
@@ -88,18 +88,13 @@ class MyProxyService {
|
||||
std::vector<std::shared_ptr<mscclpp::Host2DeviceSemaphore>> deviceSemaphores2_;
|
||||
std::vector<std::shared_ptr<mscclpp::Connection>> connections_;
|
||||
mscclpp::Proxy proxy_;
|
||||
int deviceNumaNode_;
|
||||
|
||||
public:
|
||||
MyProxyService(mscclpp::Communicator& comm, int* data_d, int dataSize)
|
||||
: dataSize_(dataSize),
|
||||
remoteMemories_(world_size),
|
||||
connections_(world_size),
|
||||
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }, [&]() { bindThread(); }) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
deviceNumaNode_ = mscclpp::getDeviceNumaNode(cudaDevice);
|
||||
|
||||
proxy_([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); }) {
|
||||
int thisNode = rankToNode(rank);
|
||||
int cudaNum = rankToLocalRank(rank);
|
||||
std::string ibDevStr = "mlx5_ib" + std::to_string(cudaNum);
|
||||
@@ -144,12 +139,6 @@ class MyProxyService {
|
||||
}
|
||||
}
|
||||
|
||||
void bindThread() {
|
||||
if (deviceNumaNode_ >= 0) {
|
||||
mscclpp::numaBind(deviceNumaNode_);
|
||||
}
|
||||
}
|
||||
|
||||
mscclpp::ProxyHandlerResult handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
|
||||
static int flusher = 0;
|
||||
if (triggerRaw.fst > 0) {
|
||||
@@ -176,7 +165,7 @@ class MyProxyService {
|
||||
|
||||
void stop() { proxy_.stop(); }
|
||||
|
||||
mscclpp::Fifo& fifo() { return proxy_.fifo(); }
|
||||
std::shared_ptr<mscclpp::Fifo> fifo() { return proxy_.fifo(); }
|
||||
|
||||
mscclpp::Host2DeviceSemaphore::DeviceHandle getDeviceHandle1(int r) { return deviceSemaphores1_[r]->deviceHandle(); }
|
||||
|
||||
@@ -249,7 +238,7 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
if (rank == 0) printf("Launching MSCCL++ proxy threads\n");
|
||||
proxyService.start();
|
||||
mscclpp::FifoDeviceHandle fifo = proxyService.fifo().deviceHandle();
|
||||
mscclpp::FifoDeviceHandle fifo = proxyService.fifo()->deviceHandle();
|
||||
if (rank == 0) printf("Testing the correctness of AllGather implementation\n");
|
||||
cudaStream_t stream;
|
||||
MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
|
||||
@@ -536,16 +536,11 @@ class AllGatherProxyService : public mscclpp::BaseProxyService {
|
||||
};
|
||||
|
||||
AllGatherProxyService::AllGatherProxyService(int worldSize, int rank, int cudaDevice)
|
||||
: worldSize_(worldSize),
|
||||
rank_(rank),
|
||||
cudaDevice_(cudaDevice),
|
||||
sendBytes_(0),
|
||||
proxy_(
|
||||
std::make_shared<mscclpp::Proxy>([&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
|
||||
[&]() {
|
||||
int deviceNumaNode = getDeviceNumaNode(cudaDevice_);
|
||||
numaBind(deviceNumaNode);
|
||||
})) {}
|
||||
: worldSize_(worldSize), rank_(rank), cudaDevice_(cudaDevice), sendBytes_(0) {
|
||||
MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice));
|
||||
auto handlerFunc = [&](mscclpp::ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); };
|
||||
proxy_ = std::make_shared<mscclpp::Proxy>(handlerFunc);
|
||||
}
|
||||
|
||||
mscclpp::ProxyHandlerResult AllGatherProxyService::handleTrigger(mscclpp::ProxyTrigger triggerRaw) {
|
||||
size_t offset = rank_ * sendBytes_;
|
||||
|
||||
@@ -275,6 +275,7 @@ void BaseTestEngine::runTest() {
|
||||
if (args_.reportErrors) {
|
||||
this->coll_->setupCollTest(args_, size);
|
||||
this->coll_->initData(this->args_, this->getSendBuff(), this->getExpectedBuff());
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
this->barrier();
|
||||
this->coll_->runColl(args_, stream_);
|
||||
CUDATHROW(cudaDeviceSynchronize());
|
||||
|
||||
Reference in New Issue
Block a user