Fix Python binding of exceptions (#444)

* Fixed errors to be catchable from Python code
* Skip IB tests in Python unit tests when IB ports are down
This commit is contained in:
Changho Hwang
2025-01-09 11:58:23 -08:00
committed by GitHub
parent 80abce59ef
commit f2b52c6318
8 changed files with 66 additions and 41 deletions

View File

@@ -4,6 +4,13 @@
import os as _os
from ._mscclpp import (
ErrorCode,
BaseError,
Error,
SysError,
CudaError,
CuError,
IbError,
Communicator,
Connection,
connect_nvls_collective,

View File

@@ -9,7 +9,19 @@
namespace nb = nanobind;
using namespace mscclpp;
void register_error(nb::module_& m) {
// Registers a nanobind exception translator for the C++ exception type `name_`.
//
// The Python exception class is looked up as the module attribute with the same
// name (`m.attr(#name_)`), so that attribute must already exist on `m` when the
// macro is expanded. Its PyObject pointer is handed to the translator as `payload`.
//
// When a C++ exception of type `name_` (or a type derived from it) is translated,
// the Python exception is raised with the 2-tuple `(error_code, message)` as its
// value, where `error_code` is `e.getErrorCode()` converted to a Python int and
// `message` is `e.what()` as a Python str. Any other exception type is not caught
// here and leaves the lambda unchanged (presumably so that nanobind can try the
// next registered translator -- confirm against the nanobind documentation).
//
// Requires a variable `m` (the nanobind module) in the expanding scope.
#define REGISTER_EXCEPTION_TRANSLATOR(name_)                                                         \
  nb::register_exception_translator(                                                                 \
      [](const std::exception_ptr &p, void *payload) {                                               \
        try {                                                                                        \
          std::rethrow_exception(p);                                                                 \
        } catch (const name_ &e) {                                                                   \
          /* Py_BuildValue yields ONE new reference; the tuple owns both of its items. */            \
          PyObject *args = Py_BuildValue("(ls)", static_cast<long>(e.getErrorCode()), e.what());     \
          if (args != nullptr) {                                                                     \
            PyErr_SetObject(reinterpret_cast<PyObject *>(payload), args);                            \
            /* PyErr_SetObject takes its own reference, so ours must be released. */                 \
            Py_DECREF(args);                                                                         \
          }                                                                                          \
          /* If args is NULL, Py_BuildValue already set a Python error; that one is raised. */       \
        }                                                                                            \
      },                                                                                             \
      m.attr(#name_).ptr());
void register_error(nb::module_ &m) {
nb::enum_<ErrorCode>(m, "ErrorCode")
.value("SystemError", ErrorCode::SystemError)
.value("InternalError", ErrorCode::InternalError)
@@ -19,24 +31,21 @@ void register_error(nb::module_& m) {
.value("Aborted", ErrorCode::Aborted)
.value("ExecutorError", ErrorCode::ExecutorError);
nb::class_<BaseError>(m, "BaseError")
.def(nb::init<std::string&, int>(), nb::arg("message"), nb::arg("errorCode"))
.def("get_error_code", &BaseError::getErrorCode)
.def("what", &BaseError::what);
nb::exception<BaseError>(m, "BaseError");
REGISTER_EXCEPTION_TRANSLATOR(BaseError);
nb::class_<Error, BaseError>(m, "Error")
.def(nb::init<const std::string&, ErrorCode>(), nb::arg("message"), nb::arg("errorCode"))
.def("get_error_code", &Error::getErrorCode);
nb::exception<Error>(m, "Error", m.attr("BaseError").ptr());
REGISTER_EXCEPTION_TRANSLATOR(Error);
nb::class_<SysError, BaseError>(m, "SysError")
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
nb::exception<SysError>(m, "SysError", m.attr("BaseError").ptr());
REGISTER_EXCEPTION_TRANSLATOR(SysError);
nb::class_<CudaError, BaseError>(m, "CudaError")
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
nb::exception<CudaError>(m, "CudaError", m.attr("BaseError").ptr());
REGISTER_EXCEPTION_TRANSLATOR(CudaError);
nb::class_<CuError, BaseError>(m, "CuError")
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
nb::exception<CuError>(m, "CuError", m.attr("BaseError").ptr());
REGISTER_EXCEPTION_TRANSLATOR(CuError);
nb::class_<IbError, BaseError>(m, "IbError")
.def(nb::init<const std::string&, int>(), nb::arg("message"), nb::arg("errorCode"));
nb::exception<IbError>(m, "IbError", m.attr("BaseError").ptr());
REGISTER_EXCEPTION_TRANSLATOR(IbError);
}

View File

@@ -12,6 +12,8 @@ import netifaces as ni
import pytest
from mscclpp import (
ErrorCode,
Error,
DataType,
EndpointConfig,
ExecutionPlan,
@@ -44,7 +46,7 @@ def all_ranks_on_the_same_node(mpi_group: MpiGroup):
@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("ifIpPortTrio", ["eth0:localhost:50000", ethernet_interface_name, ""])
@pytest.mark.parametrize("ifIpPortTrio", [f"{ethernet_interface_name}:localhost:50000", ethernet_interface_name, ""])
def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
if (ethernet_interface_name in ni.interfaces()) is False:
pytest.skip(f"{ethernet_interface_name} is not an interface to use on this node")
@@ -146,7 +148,12 @@ def create_group_and_connection(mpi_group: MpiGroup, transport: str):
if (transport == "NVLink" or transport == "NVLS") and all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("cannot use nvlink/nvls for cross node")
group = mscclpp_comm.CommGroup(mpi_group.comm)
connection = create_connection(group, transport)
try:
connection = create_connection(group, transport)
except Error as e:
if transport == "IB" and e.args[0] == ErrorCode.InvalidUsage:
pytest.skip("IB not supported on this node")
raise
return group, connection