diff --git a/python/src/_py_mscclpp.cpp b/python/src/_py_mscclpp.cpp index 32c75566..55c5848f 100644 --- a/python/src/_py_mscclpp.cpp +++ b/python/src/_py_mscclpp.cpp @@ -105,33 +105,6 @@ NB_MODULE(_py_mscclpp, m) { m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES; - nb::enum_(m, "reduce_op") - .value("sum", mscclppRedOp_t::mscclppSum) - .value("prod", mscclppRedOp_t::mscclppProd) - .value("max", mscclppRedOp_t::mscclppMax) - .value("min", mscclppRedOp_t::mscclppMin) - .value("avg", mscclppRedOp_t::mscclppAvg); - - nb::enum_(m, "dtype") - .value("int8", mscclppDataType_t::mscclppInt8) - .value("char", mscclppDataType_t::mscclppChar) - .value("uint8", mscclppDataType_t::mscclppUint8) - .value("int32", mscclppDataType_t::mscclppInt32) - .value("uint32", mscclppDataType_t::mscclppUint32) - .value("int", mscclppDataType_t::mscclppInt) - .value("int64", mscclppDataType_t::mscclppInt64) - .value("uint64", mscclppDataType_t::mscclppUint64) - .value("float16", mscclppDataType_t::mscclppFloat16) - .value("half", mscclppDataType_t::mscclppHalf) - .value("float32", mscclppDataType_t::mscclppFloat32) - .value("float", mscclppDataType_t::mscclppFloat) - .value("float64", mscclppDataType_t::mscclppFloat64) - .value("double", mscclppDataType_t::mscclppDouble) -#if defined(__CUDA_BF16_TYPES_EXIST__) - .value("bfloat16", mscclppDataType_t::mscclppBfloat16) -#endif - ; - nb::class_(m, "MscclppUniqueId") .def_ro_static("__doc__", &DOC_MscclppUniqueId) .def_static( @@ -172,7 +145,7 @@ NB_MODULE(_py_mscclpp, m) { comm._is_open = true; return maybe( mscclppCommInitRank( - &comm._handle, world_size, rank, address.c_str()), + &comm._handle, world_size, address.c_str(), rank), comm, "Failed to initialize comms: %s rank=%d world_size=%d", address, diff --git a/python/src/mscclpp/__init__.py b/python/src/mscclpp/__init__.py index 8778c097..e825b92d 100644 --- a/python/src/mscclpp/__init__.py +++ b/python/src/mscclpp/__init__.py @@ -4,13 +4,8 @@ __all__ = ( "MscclppUniqueId", "MSCCLPP_UNIQUE_ID_BYTES", "MscclppComm", - "dtype", - "reduce_op", ) -dtype = _py_mscclpp.dtype -reduce_op = _py_mscclpp.reduce_op - MscclppUniqueId = _py_mscclpp.MscclppUniqueId MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES diff --git a/python/src/mscclpp/test_mscclpp.py b/python/src/mscclpp/test_mscclpp.py index 864ae85e..e67f2770 100644 --- a/python/src/mscclpp/test_mscclpp.py +++ b/python/src/mscclpp/test_mscclpp.py @@ -3,63 +3,6 @@ import hamcrest import mscclpp -class DTypeTest(unittest.TestCase): - def test(self) -> None: - for name, val in [ - ('int8', 0), - ('char', 0), - ('uint8', 1), - ('int32', 2), - ('int', 2), - ('uint32', 3), - ('int64', 4), - ('uint64', 5), - ('float16', 6), - ('half', 6), - ('float32', 7), - ('float', 7), - ('float64', 8), - ('double', 8), - ]: - try: - dtype = getattr(mscclpp.dtype, name) - hamcrest.assert_that( - mscclpp.dtype(val), - hamcrest.equal_to(dtype), - reason=(name, val), - ) - hamcrest.assert_that( - int(mscclpp.dtype(val)), - hamcrest.equal_to(val), - reason=(name, val), - ) - except Exception as e: - raise AssertionError((name, val)) from e - -class ReduceOpTest(unittest.TestCase): - def test(self) -> None: - for name, val in [ - ('sum', 0), - ('prod', 1), - ('max', 2), - ('min', 3), - ('avg', 4), - ]: - try: - dtype = getattr(mscclpp.reduce_op, name) - hamcrest.assert_that( - mscclpp.reduce_op(val), - hamcrest.equal_to(dtype), - reason=(name, val), - ) - hamcrest.assert_that( - int(mscclpp.reduce_op(val)), - hamcrest.equal_to(val), - reason=(name, val), - ) - except Exception as e: - raise AssertionError((name, val)) from e - class UniqueIdTest(unittest.TestCase): def test_no_constructor(self) -> None: