mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-20 06:49:29 +00:00
include left out lib; add enums
This commit is contained in:
@@ -99,11 +99,39 @@ static const std::string DOC_MscclppUniqueId =
|
||||
|
||||
static const std::string DOC_MscclppComm = "MSCCLPP Communications Handle";
|
||||
|
||||
|
||||
NB_MODULE(_py_mscclpp, m) {
|
||||
m.doc() = "Python bindings for MSCCLPP: which is not NCCL";
|
||||
|
||||
m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES;
|
||||
|
||||
nb::enum_<mscclppRedOp_t>(m, "reduce_op")
|
||||
.value("sum", mscclppRedOp_t::mscclppSum)
|
||||
.value("prod", mscclppRedOp_t::mscclppProd)
|
||||
.value("max", mscclppRedOp_t::mscclppMax)
|
||||
.value("min", mscclppRedOp_t::mscclppMin)
|
||||
.value("avg", mscclppRedOp_t::mscclppAvg);
|
||||
|
||||
nb::enum_<mscclppDataType_t>(m, "dtype")
|
||||
.value("int8", mscclppDataType_t::mscclppInt8)
|
||||
.value("char", mscclppDataType_t::mscclppChar)
|
||||
.value("uint8", mscclppDataType_t::mscclppUint8)
|
||||
.value("int32", mscclppDataType_t::mscclppInt32)
|
||||
.value("uint32", mscclppDataType_t::mscclppUint32)
|
||||
.value("int", mscclppDataType_t::mscclppInt)
|
||||
.value("int64", mscclppDataType_t::mscclppInt64)
|
||||
.value("uint64", mscclppDataType_t::mscclppUint64)
|
||||
.value("float16", mscclppDataType_t::mscclppFloat16)
|
||||
.value("half", mscclppDataType_t::mscclppHalf)
|
||||
.value("float32", mscclppDataType_t::mscclppFloat32)
|
||||
.value("float", mscclppDataType_t::mscclppFloat)
|
||||
.value("float64", mscclppDataType_t::mscclppFloat64)
|
||||
.value("double", mscclppDataType_t::mscclppDouble)
|
||||
#if defined(__CUDA_BF16_TYPES_EXIST__)
|
||||
.value("bfloat16", mscclppDataType_t::mscclppBfloat16)
|
||||
#endif
|
||||
;
|
||||
|
||||
nb::class_<mscclppUniqueId>(m, "MscclppUniqueId")
|
||||
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
|
||||
.def_static(
|
||||
|
||||
18
python/src/mscclpp/__init__.py
Normal file
18
python/src/mscclpp/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from . import _py_mscclpp
|
||||
|
||||
__all__ = (
|
||||
"MscclppUniqueId",
|
||||
"MSCCLPP_UNIQUE_ID_BYTES",
|
||||
"MscclppComm",
|
||||
"dtype",
|
||||
"reduce_op",
|
||||
)
|
||||
|
||||
dtype = _py_mscclpp.dtype
|
||||
reduce_op = _py_mscclpp.reduce_op
|
||||
|
||||
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
|
||||
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES
|
||||
|
||||
MscclppComm = _py_mscclpp.MscclppComm
|
||||
|
||||
107
python/src/mscclpp/test_mscclpp.py
Normal file
107
python/src/mscclpp/test_mscclpp.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import unittest
|
||||
import hamcrest
|
||||
|
||||
import mscclpp
|
||||
|
||||
class DTypeTest(unittest.TestCase):
|
||||
def test(self) -> None:
|
||||
for name, val in [
|
||||
('int8', 0),
|
||||
('char', 0),
|
||||
('uint8', 1),
|
||||
('int32', 2),
|
||||
('int', 2),
|
||||
('uint32', 3),
|
||||
('int64', 4),
|
||||
('uint64', 5),
|
||||
('float16', 6),
|
||||
('half', 6),
|
||||
('float32', 7),
|
||||
('float', 7),
|
||||
('float64', 8),
|
||||
('double', 8),
|
||||
]:
|
||||
try:
|
||||
dtype = getattr(mscclpp.dtype, name)
|
||||
hamcrest.assert_that(
|
||||
mscclpp.dtype(val),
|
||||
hamcrest.equal_to(dtype),
|
||||
reason=(name, val),
|
||||
)
|
||||
hamcrest.assert_that(
|
||||
int(mscclpp.dtype(val)),
|
||||
hamcrest.equal_to(val),
|
||||
reason=(name, val),
|
||||
)
|
||||
except Exception as e:
|
||||
raise AssertionError((name, val)) from e
|
||||
|
||||
class ReduceOpTest(unittest.TestCase):
|
||||
def test(self) -> None:
|
||||
for name, val in [
|
||||
('sum', 0),
|
||||
('prod', 1),
|
||||
('max', 2),
|
||||
('min', 3),
|
||||
('avg', 4),
|
||||
]:
|
||||
try:
|
||||
dtype = getattr(mscclpp.reduce_op, name)
|
||||
hamcrest.assert_that(
|
||||
mscclpp.reduce_op(val),
|
||||
hamcrest.equal_to(dtype),
|
||||
reason=(name, val),
|
||||
)
|
||||
hamcrest.assert_that(
|
||||
int(mscclpp.reduce_op(val)),
|
||||
hamcrest.equal_to(val),
|
||||
reason=(name, val),
|
||||
)
|
||||
except Exception as e:
|
||||
raise AssertionError((name, val)) from e
|
||||
|
||||
|
||||
class UniqueIdTest(unittest.TestCase):
|
||||
def test_no_constructor(self) -> None:
|
||||
hamcrest.assert_that(
|
||||
hamcrest.calling(mscclpp.MscclppUniqueId).with_args(),
|
||||
hamcrest.raises(
|
||||
TypeError,
|
||||
"no constructor",
|
||||
),
|
||||
)
|
||||
|
||||
def test_getUniqueId(self) -> None:
|
||||
myId = mscclpp.MscclppUniqueId.from_context()
|
||||
|
||||
hamcrest.assert_that(
|
||||
myId.bytes(),
|
||||
hamcrest.has_length(mscclpp.MSCCLPP_UNIQUE_ID_BYTES),
|
||||
)
|
||||
|
||||
# from_bytes should work
|
||||
copy = mscclpp.MscclppUniqueId.from_bytes(myId.bytes())
|
||||
hamcrest.assert_that(
|
||||
copy.bytes(),
|
||||
hamcrest.equal_to(myId.bytes()),
|
||||
)
|
||||
|
||||
# bad size
|
||||
hamcrest.assert_that(
|
||||
hamcrest.calling(mscclpp.MscclppUniqueId.from_bytes).with_args(b'abc'),
|
||||
hamcrest.raises(
|
||||
ValueError,
|
||||
f"Requires exactly {mscclpp.MSCCLPP_UNIQUE_ID_BYTES} bytes; found 3"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CommsTest(unittest.TestCase):
|
||||
def _test(self) -> None:
|
||||
# this hangs forever
|
||||
comm = mscclpp.MscclppComm.init_rank_from_address(
|
||||
address="127.0.0.1:50000",
|
||||
rank=0,
|
||||
world_size=2,
|
||||
)
|
||||
comm.close()
|
||||
Reference in New Issue
Block a user