Update interface to let user change fifo size (#243)

Related to this issue:
https://github.com/microsoft/mscclpp/issues/242. The user may use more
threads than the number specified in `fifo_size` to interact with the
FIFO. In this case, the FIFO may behave unexpectedly.
Update the interface to let users change the FIFO size as needed.
This commit is contained in:
Binyang Li
2024-01-09 22:14:36 -08:00
committed by GitHub
parent e7d3e2d44b
commit 163cba08c8
7 changed files with 18 additions and 14 deletions

View File

@@ -12,12 +12,14 @@
namespace mscclpp {
constexpr size_t DEFAULT_FIFO_SIZE = 128;
/// A class representing a host proxy FIFO that can consume work elements pushed by device threads.
class Fifo {
public:
/// Constructs a new @ref Fifo object.
/// @param size The number of entries in the FIFO.
Fifo(int size = 128);
Fifo(int size = DEFAULT_FIFO_SIZE);
/// Destroys the @ref Fifo object.
~Fifo();

View File

@@ -22,8 +22,8 @@ using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
class Proxy {
public:
Proxy(ProxyHandler handler, std::function<void()> threadInit);
Proxy(ProxyHandler handler);
Proxy(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize = DEFAULT_FIFO_SIZE);
Proxy(ProxyHandler handler, size_t fifoSize = DEFAULT_FIFO_SIZE);
~Proxy();
void start();
@@ -41,4 +41,4 @@ class Proxy {
} // namespace mscclpp
#endif // MSCCLPP_PROXY_HPP_
#endif // MSCCLPP_PROXY_HPP_

View File

@@ -26,7 +26,7 @@ class BaseProxyService {
class ProxyService : public BaseProxyService {
public:
/// Constructor.
ProxyService();
ProxyService(size_t fifoSize = DEFAULT_FIFO_SIZE);
/// Build and add a semaphore to the proxy service.
/// @param connection The connection associated with the semaphore.

View File

@@ -21,7 +21,7 @@ void register_fifo(nb::module_& m) {
});
nb::class_<Fifo>(m, "Fifo")
.def(nb::init<int>(), nb::arg("size") = 128)
.def(nb::init<int>(), nb::arg("size") = DEFAULT_FIFO_SIZE)
.def("poll", &Fifo::poll)
.def("pop", &Fifo::pop)
.def("flush_tail", &Fifo::flushTail, nb::arg("sync") = false)

View File

@@ -16,7 +16,7 @@ void register_proxy_channel(nb::module_& m) {
.def("stop_proxy", &BaseProxyService::stopProxy);
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
.def(nb::init<>())
.def(nb::init<size_t>(), nb::arg("fifoSize") = DEFAULT_FIFO_SIZE)
.def("start_proxy", &ProxyService::startProxy)
.def("stop_proxy", &ProxyService::stopProxy)
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))

View File

@@ -25,15 +25,17 @@ struct Proxy::Impl {
std::thread service;
std::atomic_bool running;
Impl(ProxyHandler handler, std::function<void()> threadInit)
: handler(handler), threadInit(threadInit), running(false) {}
Impl(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize)
: handler(handler), threadInit(threadInit), fifo(fifoSize), running(false) {}
};
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit) {
pimpl = std::make_unique<Impl>(handler, threadInit);
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize) {
pimpl = std::make_unique<Impl>(handler, threadInit, fifoSize);
}
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {}) {}
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, size_t fifoSize)
: Proxy(
handler, [] {}, fifoSize) {}
MSCCLPP_API_CPP Proxy::~Proxy() {
if (pimpl) {

View File

@@ -16,9 +16,9 @@ MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, std::shared_
MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src)
: proxyChan_(proxyChan), dst_(dst), src_(src) {}
MSCCLPP_API_CPP ProxyService::ProxyService()
MSCCLPP_API_CPP ProxyService::ProxyService(size_t fifoSize)
: proxy_(std::make_shared<Proxy>([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
[&]() { bindThread(); })) {
[&]() { bindThread(); }, fifoSize)) {
int cudaDevice;
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
deviceNumaNode = getDeviceNumaNode(cudaDevice);