mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
Update interface to let user change fifo size (#243)
Related with this issue: https://github.com/microsoft/mscclpp/issues/242. The user may use more threads than the number specified in `fifo_size` to interact with the FIFO. In this case, there will be unexpected behavior. Update the interface to let user change fifo size on their demands.
This commit is contained in:
@@ -12,12 +12,14 @@
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
constexpr size_t DEFAULT_FIFO_SIZE = 128;
|
||||
|
||||
/// A class representing a host proxy FIFO that can consume work elements pushed by device threads.
|
||||
class Fifo {
|
||||
public:
|
||||
/// Constructs a new @ref Fifo object.
|
||||
/// @param size The number of entires in the FIFO.
|
||||
Fifo(int size = 128);
|
||||
Fifo(int size = DEFAULT_FIFO_SIZE);
|
||||
|
||||
/// Destroys the @ref Fifo object.
|
||||
~Fifo();
|
||||
|
||||
@@ -22,8 +22,8 @@ using ProxyHandler = std::function<ProxyHandlerResult(ProxyTrigger)>;
|
||||
|
||||
class Proxy {
|
||||
public:
|
||||
Proxy(ProxyHandler handler, std::function<void()> threadInit);
|
||||
Proxy(ProxyHandler handler);
|
||||
Proxy(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize = DEFAULT_FIFO_SIZE);
|
||||
Proxy(ProxyHandler handler, size_t fifoSize = DEFAULT_FIFO_SIZE);
|
||||
~Proxy();
|
||||
|
||||
void start();
|
||||
@@ -41,4 +41,4 @@ class Proxy {
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_PROXY_HPP_
|
||||
#endif // MSCCLPP_PROXY_HPP_
|
||||
|
||||
@@ -26,7 +26,7 @@ class BaseProxyService {
|
||||
class ProxyService : public BaseProxyService {
|
||||
public:
|
||||
/// Constructor.
|
||||
ProxyService();
|
||||
ProxyService(size_t fifoSize = DEFAULT_FIFO_SIZE);
|
||||
|
||||
/// Build and add a semaphore to the proxy service.
|
||||
/// @param connection The connection associated with the semaphore.
|
||||
|
||||
@@ -21,7 +21,7 @@ void register_fifo(nb::module_& m) {
|
||||
});
|
||||
|
||||
nb::class_<Fifo>(m, "Fifo")
|
||||
.def(nb::init<int>(), nb::arg("size") = 128)
|
||||
.def(nb::init<int>(), nb::arg("size") = DEFAULT_FIFO_SIZE)
|
||||
.def("poll", &Fifo::poll)
|
||||
.def("pop", &Fifo::pop)
|
||||
.def("flush_tail", &Fifo::flushTail, nb::arg("sync") = false)
|
||||
|
||||
@@ -16,7 +16,7 @@ void register_proxy_channel(nb::module_& m) {
|
||||
.def("stop_proxy", &BaseProxyService::stopProxy);
|
||||
|
||||
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
|
||||
.def(nb::init<>())
|
||||
.def(nb::init<size_t>(), nb::arg("fifoSize") = DEFAULT_FIFO_SIZE)
|
||||
.def("start_proxy", &ProxyService::startProxy)
|
||||
.def("stop_proxy", &ProxyService::stopProxy)
|
||||
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
|
||||
|
||||
12
src/proxy.cc
12
src/proxy.cc
@@ -25,15 +25,17 @@ struct Proxy::Impl {
|
||||
std::thread service;
|
||||
std::atomic_bool running;
|
||||
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit)
|
||||
: handler(handler), threadInit(threadInit), running(false) {}
|
||||
Impl(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize)
|
||||
: handler(handler), threadInit(threadInit), fifo(fifoSize), running(false) {}
|
||||
};
|
||||
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit) {
|
||||
pimpl = std::make_unique<Impl>(handler, threadInit);
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, std::function<void()> threadInit, size_t fifoSize) {
|
||||
pimpl = std::make_unique<Impl>(handler, threadInit, fifoSize);
|
||||
}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler) : Proxy(handler, [] {}) {}
|
||||
MSCCLPP_API_CPP Proxy::Proxy(ProxyHandler handler, size_t fifoSize)
|
||||
: Proxy(
|
||||
handler, [] {}, fifoSize) {}
|
||||
|
||||
MSCCLPP_API_CPP Proxy::~Proxy() {
|
||||
if (pimpl) {
|
||||
|
||||
@@ -16,9 +16,9 @@ MSCCLPP_API_CPP ProxyChannel::ProxyChannel(SemaphoreId semaphoreId, std::shared_
|
||||
MSCCLPP_API_CPP SimpleProxyChannel::SimpleProxyChannel(ProxyChannel proxyChan, MemoryId dst, MemoryId src)
|
||||
: proxyChan_(proxyChan), dst_(dst), src_(src) {}
|
||||
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService()
|
||||
MSCCLPP_API_CPP ProxyService::ProxyService(size_t fifoSize)
|
||||
: proxy_(std::make_shared<Proxy>([&](ProxyTrigger triggerRaw) { return handleTrigger(triggerRaw); },
|
||||
[&]() { bindThread(); })) {
|
||||
[&]() { bindThread(); }, fifoSize)) {
|
||||
int cudaDevice;
|
||||
MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice));
|
||||
deviceNumaNode = getDeviceNumaNode(cudaDevice);
|
||||
|
||||
Reference in New Issue
Block a user