mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
Fix Python binding of exceptions (#444)
* Fixed errors to be catchable from Python code * Skip IB tests in Python unit tests when IB ports are down
This commit is contained in:
@@ -13,13 +13,10 @@
|
||||
#include <bitset>
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
#include <mscclpp/gpu_utils.hpp>
|
||||
#include <mscclpp/errors.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "errors.hpp"
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
#define MSCCLPP_UNIQUE_ID_BYTES 128
|
||||
@@ -303,23 +300,6 @@ inline TransportFlags operator^(Transport transport1, Transport transport2) {
|
||||
return TransportFlags(transport1) ^ transport2;
|
||||
}
|
||||
|
||||
/// Get the number of available InfiniBand devices.
|
||||
///
|
||||
/// @return The number of available InfiniBand devices.
|
||||
int getIBDeviceCount();
|
||||
|
||||
/// Get the name of the InfiniBand device associated with the specified transport.
|
||||
///
|
||||
/// @param ibTransport The InfiniBand transport to get the device name for.
|
||||
/// @return The name of the InfiniBand device associated with the specified transport.
|
||||
std::string getIBDeviceName(Transport ibTransport);
|
||||
|
||||
/// Get the InfiniBand transport associated with the specified device name.
|
||||
///
|
||||
/// @param ibDeviceName The name of the InfiniBand device to get the transport for.
|
||||
/// @return The InfiniBand transport associated with the specified device name.
|
||||
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);
|
||||
|
||||
class Context;
|
||||
class Connection;
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <memory>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mscclpp {
|
||||
|
||||
@@ -193,6 +193,9 @@ void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMe
|
||||
detail::gpuMemcpy(dst, src, nelems * sizeof(T), kind);
|
||||
}
|
||||
|
||||
/// Check if NVLink SHARP (NVLS) is supported.
|
||||
///
|
||||
/// @return True if NVLink SHARP (NVLS) is supported, false otherwise.
|
||||
bool isNvlsSupported();
|
||||
|
||||
/// Allocates a GPU memory space specialized for communication. The memory is zeroed out. Get the device pointer by
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#define MSCCLPP_UTILS_HPP_
|
||||
|
||||
#include <chrono>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace mscclpp {
|
||||
@@ -37,6 +38,23 @@ struct ScopedTimer : public Timer {
|
||||
|
||||
std::string getHostName(int maxlen, const char delim);
|
||||
|
||||
/// Get the number of available InfiniBand devices.
|
||||
///
|
||||
/// @return The number of available InfiniBand devices.
|
||||
int getIBDeviceCount();
|
||||
|
||||
/// Get the name of the InfiniBand device associated with the specified transport.
|
||||
///
|
||||
/// @param ibTransport The InfiniBand transport to get the device name for.
|
||||
/// @return The name of the InfiniBand device associated with the specified transport.
|
||||
std::string getIBDeviceName(Transport ibTransport);
|
||||
|
||||
/// Get the InfiniBand transport associated with the specified device name.
|
||||
///
|
||||
/// @param ibDeviceName The name of the InfiniBand device to get the transport for.
|
||||
/// @return The InfiniBand transport associated with the specified device name.
|
||||
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);
|
||||
|
||||
} // namespace mscclpp
|
||||
|
||||
#endif // MSCCLPP_UTILS_HPP_
|
||||
|
||||
@@ -4,6 +4,13 @@
|
||||
import os as _os
|
||||
|
||||
from ._mscclpp import (
|
||||
ErrorCode,
|
||||
BaseError,
|
||||
Error,
|
||||
SysError,
|
||||
CudaError,
|
||||
CuError,
|
||||
IbError,
|
||||
Communicator,
|
||||
Connection,
|
||||
connect_nvls_collective,
|
||||
|
||||
@@ -9,7 +9,19 @@
|
||||
namespace nb = nanobind;
|
||||
using namespace mscclpp;
|
||||
|
||||
void register_error(nb::module_& m) {
|
||||
#define REGISTER_EXCEPTION_TRANSLATOR(name_) \
|
||||
nb::register_exception_translator( \
|
||||
[](const std::exception_ptr &p, void *payload) { \
|
||||
try { \
|
||||
std::rethrow_exception(p); \
|
||||
} catch (const name_ &e) { \
|
||||
PyErr_SetObject(reinterpret_cast<PyObject *>(payload), \
|
||||
PyTuple_Pack(2, PyLong_FromLong(long(e.getErrorCode())), PyUnicode_FromString(e.what()))); \
|
||||
} \
|
||||
}, \
|
||||
m.attr(#name_).ptr());
|
||||
|
||||
void register_error(nb::module_ &m) {
|
||||
nb::enum_<ErrorCode>(m, "ErrorCode")
|
||||
.value("SystemError", ErrorCode::SystemError)
|
||||
.value("InternalError", ErrorCode::InternalError)
|
||||
@@ -19,24 +31,21 @@ void register_error(nb::module_& m) {
|
||||
.value("Aborted", ErrorCode::Aborted)
|
||||
.value("ExecutorError", ErrorCode::ExecutorError);
|
||||
|
||||
nb::class_<BaseError>(m, "BaseError")
|
||||
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))
|
||||
.def("get_error_code", &BaseError::getErrorCode)
|
||||
.def("what", &BaseError::what);
|
||||
nb::exception<BaseError>(m, "BaseError");
|
||||
REGISTER_EXCEPTION_TRANSLATOR(BaseError);
|
||||
|
||||
nb::class_<Error, BaseError>(m, "Error")
|
||||
.def(nb::init<const std::string&, ErrorCode>(), nb::arg("message"), nb::arg("errorCode"))
|
||||
.def("get_error_code", &Error::getErrorCode);
|
||||
nb::exception<Error>(m, "Error", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(Error);
|
||||
|
||||
nb::class_<SysError, BaseError>(m, "SysError")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
nb::exception<SysError>(m, "SysError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(SysError);
|
||||
|
||||
nb::class_<CudaError, BaseError>(m, "CudaError")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
nb::exception<CudaError>(m, "CudaError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(CudaError);
|
||||
|
||||
nb::class_<CuError, BaseError>(m, "CuError")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
nb::exception<CuError>(m, "CuError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(CuError);
|
||||
|
||||
nb::class_<IbError, BaseError>(m, "IbError")
|
||||
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
|
||||
nb::exception<IbError>(m, "IbError", m.attr("BaseError").ptr());
|
||||
REGISTER_EXCEPTION_TRANSLATOR(IbError);
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import netifaces as ni
|
||||
import pytest
|
||||
|
||||
from mscclpp import (
|
||||
ErrorCode,
|
||||
Error,
|
||||
DataType,
|
||||
EndpointConfig,
|
||||
ExecutionPlan,
|
||||
@@ -44,7 +46,7 @@ def all_ranks_on_the_same_node(mpi_group: MpiGroup):
|
||||
|
||||
|
||||
@parametrize_mpi_groups(2, 4, 8, 16)
|
||||
@pytest.mark.parametrize("ifIpPortTrio", ["eth0:localhost:50000", ethernet_interface_name, ""])
|
||||
@pytest.mark.parametrize("ifIpPortTrio", [f"{ethernet_interface_name}:localhost:50000", ethernet_interface_name, ""])
|
||||
def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
|
||||
if (ethernet_interface_name in ni.interfaces()) is False:
|
||||
pytest.skip(f"{ethernet_interface_name} is not an interface to use on this node")
|
||||
@@ -146,7 +148,12 @@ def create_group_and_connection(mpi_group: MpiGroup, transport: str):
|
||||
if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
|
||||
pytest.skip("cannot use nvlink/nvls for cross node")
|
||||
group = mscclpp_comm.CommGroup(mpi_group.comm)
|
||||
connection = create_connection(group, transport)
|
||||
try:
|
||||
connection = create_connection(group, transport)
|
||||
except Error as e:
|
||||
if transport == "IB" and e.args[0] == ErrorCode.InvalidUsage:
|
||||
pytest.skip("IB not supported on this node")
|
||||
raise
|
||||
return group, connection
|
||||
|
||||
|
||||
|
||||
@@ -367,10 +367,10 @@ IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRec
|
||||
if (port == -1) {
|
||||
port = this->getAnyActivePort();
|
||||
if (port == -1) {
|
||||
throw mscclpp::Error("No active port found", ErrorCode::InternalError);
|
||||
throw mscclpp::Error("No active port found", ErrorCode::InvalidUsage);
|
||||
}
|
||||
} else if (!this->isPortUsable(port)) {
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
|
||||
throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
|
||||
}
|
||||
qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
|
||||
return qps.back().get();
|
||||
|
||||
Reference in New Issue
Block a user