Mirror of https://github.com/microsoft/mscclpp.git, synced 2026-05-12 09:17:06 +00:00
I recently encountered a weird memory usage issue. After starting the proxy service on CUDA device X > 0, an unexpected extra thread appears on both GPU X and GPU 0, and the one on GPU 0 holds about 500 MB. When the proxy runs on device 0, there is no extra memory usage. The screenshot below shows that with 8 ranks, each using one GPU and starting a proxy, GPU 0 sees 7 extra threads, each consuming about 500 MB of extra memory.

<img width="1247" height="1367" alt="Screenshot 2026-02-28 000153" src="https://github.com/user-attachments/assets/cfd0d47f-319b-4ebb-bf19-dec66062e6f4" />

After tracking down where this happens, I identified the root cause in the proxy thread initialization:

```cpp
// never capture in a proxy thread
auto mode = cudaStreamCaptureModeRelaxed;
MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode));
pimpl_->threadInit();
```

The call to `cudaThreadExchangeStreamCaptureMode()` triggers resource allocation on the "current device", which is still device 0 for the newly started thread. The later `threadInit()` call comes too late to set the correct GPU.

The fix is simple: call `threadInit()` before the first CUDA call:

```cpp
pimpl_->threadInit();

// never capture in a proxy thread
auto mode = cudaStreamCaptureModeRelaxed;
MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode));
```

This guarantees that the current device is set correctly before any resource-allocating CUDA function is called.

This is the memory usage after the fix; the extra memory usage is gone.

<img width="1242" height="459" alt="Image (1)" src="https://github.com/user-attachments/assets/4256e4c8-6f1d-4844-9f77-5b2935387df9" />

---------

Co-authored-by: Binyang Li <binyli@microsoft.com>
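To illustrate the ordering requirement outside of mscclpp, here is a minimal sketch assuming a plain CUDA runtime program (the standalone `main`, the `workerDevice` value, and the omitted error checking are illustrative, not part of this change): a newly spawned thread starts with device 0 as its current device, so its first context-initializing CUDA call has to come after `cudaSetDevice()`.

```cpp
// Illustrative sketch (not mscclpp code): in a freshly spawned thread, bind the
// thread to its GPU *before* any other CUDA runtime call; otherwise the first
// call initializes a context on the default device 0 and allocates memory there.
#include <cuda_runtime.h>

#include <thread>

int main() {
  const int workerDevice = 1;  // hypothetical non-zero device owned by this rank

  std::thread worker([workerDevice] {
    // Correct order: select the device first, so the context created by the
    // next call lives on workerDevice instead of device 0.
    cudaSetDevice(workerDevice);

    // This call touches the CUDA runtime for the *current* device; had it run
    // before cudaSetDevice(), its context would have been created on device 0,
    // which is the bug described above.
    cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
    cudaThreadExchangeStreamCaptureMode(&mode);
  });
  worker.join();
  return 0;
}
```

With the two calls inside the worker swapped, the same kind of extra allocation on GPU 0 as in the screenshot above would be expected.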
127 lines · 3.8 KiB · C++
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <atomic>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu_utils.hpp>
#include <mscclpp/numa.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/utils.hpp>
#include <thread>

#include "api.h"
#include "debug.h"

namespace mscclpp {

constexpr int ProxyStopCheckPeriod = 1000;
constexpr int ProxyStartWarnPeriod = 1000;

struct Proxy::Impl {
  ProxyHandler handler;
  std::function<void()> threadInit;
  std::shared_ptr<Fifo> fifo;
  std::atomic_bool threadStarted;
  std::thread service;
  std::atomic_bool running;

  Impl(ProxyHandler handler, std::function<void()> threadInit, int fifoSize)
      : handler(handler),
        threadInit(threadInit),
        fifo(std::make_shared<Fifo>(fifoSize)),
        threadStarted(false),
        running(false) {}
};

MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit, int fifoSize) {
  pimpl_ = std::make_unique<Impl>(handler, threadInit, fifoSize);
}

MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, int fifoSize) {
  int cudaDevice;
  MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
  int deviceNumaNode = getDeviceNumaNode(cudaDevice);
  auto initFunc = [cudaDevice, deviceNumaNode]() {
    MSCCLPP_CUDATHROW(cudaSetDevice(cudaDevice));
    if (deviceNumaNode >= 0) {
      numaBind(deviceNumaNode);
    }
  };
  pimpl_ = std::make_unique<Impl>(handler, initFunc, fifoSize);
}

MSCCLPP_API_CPP Proxy::~Proxy() {
  if (pimpl_) {
    stop();
  }
}

MSCCLPP_API_CPP void Proxy::start(bool blocking) {
  pimpl_->running.store(true, std::memory_order_release);
  pimpl_->service = std::thread([this] {
    // threadInit() is responsible for setting up the runtime context for the thread.
    // The default implementation sets the CUDA device and NUMA affinity to match the main thread (see Proxy ctor).
    // It should be called before any CUDA API calls to avoid resource allocation on unwanted GPUs.
    pimpl_->threadInit();

    // never capture in a proxy thread
    auto mode = cudaStreamCaptureModeRelaxed;
    MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode));

    pimpl_->threadStarted.store(true, std::memory_order_release);

    ProxyHandler handler = this->pimpl_->handler;
    auto fifo = this->pimpl_->fifo;
    ProxyTrigger trigger;

    int runCnt = ProxyStopCheckPeriod;
    for (;;) {
      if (runCnt-- == 0) {
        runCnt = ProxyStopCheckPeriod;
        if (!this->pimpl_->running.load(std::memory_order_acquire)) {
          break;
        }
      }
      // Poll to see if we are ready to send anything
      trigger = fifo->poll();
      if (trigger.fst == 0 || trigger.snd == 0) {  // TODO: this check is a potential pitfall for custom triggers
        continue;                                  // there is one in progress
      }
      trigger.snd ^= (uint64_t{1} << uint64_t{63});  // revert the most significant bit (bit 63) of snd

      ProxyHandlerResult result = handler(trigger);

      // Send completion: reset only the high 64 bits
      fifo->pop();

      if (result == ProxyHandlerResult::Stop) {
        break;
      }
    }
  });

  if (blocking) {
    int count = ProxyStartWarnPeriod;
    while (!pimpl_->threadStarted.load(std::memory_order_acquire)) {
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
      count--;
      if (count == 0) {
        count = ProxyStartWarnPeriod;
        WARN("Proxy thread startup taking longer than expected.");
      }
    }
  }
}

MSCCLPP_API_CPP void Proxy::stop() {
  pimpl_->running.store(false, std::memory_order_release);
  if (pimpl_->service.joinable()) {
    pimpl_->service.join();
  }
  pimpl_->threadStarted.store(false, std::memory_order_release);
}

MSCCLPP_API_CPP std::shared_ptr<Fifo> Proxy::fifo() { return pimpl_->fifo; }

}  // namespace mscclpp
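As a usage note, here is a minimal sketch inferred only from the listing above (the handler body, the `Continue` result value, the fifo size, and the surrounding `main` are assumptions rather than mscclpp documentation): the two-argument constructor records the caller's current CUDA device and NUMA node and installs an `initFunc` that re-binds the proxy thread to them, which is exactly the ordering the fix above depends on.

```cpp
// Illustrative sketch, not from the repository.
#include <mscclpp/proxy.hpp>

#include <cuda_runtime.h>

int main() {
  // Hypothetical: this rank owns GPU 1; the proxy thread must end up bound to it.
  cudaSetDevice(1);

  // Proxy(handler, fifoSize) captures the current device and NUMA node, so its
  // initFunc calls cudaSetDevice(1) inside the service thread before any other CUDA call.
  mscclpp::Proxy proxy(
      [](mscclpp::ProxyTrigger trigger) -> mscclpp::ProxyHandlerResult {
        // A real handler would decode `trigger` and issue the requested work;
        // returning Continue (assumed enumerator) keeps the service loop running.
        return mscclpp::ProxyHandlerResult::Continue;
      },
      /*fifoSize=*/128);

  proxy.start(/*blocking=*/true);  // returns once the service thread has run threadInit()
  // ... device code pushes triggers into proxy.fifo() ...
  proxy.stop();
  return 0;
}
```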