rebase and fix

This commit is contained in:
Crutcher Dunnavant
2023-03-24 22:23:20 +00:00
parent 57b3c36975
commit 8b6e35d5e0
3 changed files with 1 additions and 90 deletions

View File

@@ -105,33 +105,6 @@ NB_MODULE(_py_mscclpp, m) {
m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES;
nb::enum_<mscclppRedOp_t>(m, "reduce_op")
.value("sum", mscclppRedOp_t::mscclppSum)
.value("prod", mscclppRedOp_t::mscclppProd)
.value("max", mscclppRedOp_t::mscclppMax)
.value("min", mscclppRedOp_t::mscclppMin)
.value("avg", mscclppRedOp_t::mscclppAvg);
nb::enum_<mscclppDataType_t>(m, "dtype")
.value("int8", mscclppDataType_t::mscclppInt8)
.value("char", mscclppDataType_t::mscclppChar)
.value("uint8", mscclppDataType_t::mscclppUint8)
.value("int32", mscclppDataType_t::mscclppInt32)
.value("uint32", mscclppDataType_t::mscclppUint32)
.value("int", mscclppDataType_t::mscclppInt)
.value("int64", mscclppDataType_t::mscclppInt64)
.value("uint64", mscclppDataType_t::mscclppUint64)
.value("float16", mscclppDataType_t::mscclppFloat16)
.value("half", mscclppDataType_t::mscclppHalf)
.value("float32", mscclppDataType_t::mscclppFloat32)
.value("float", mscclppDataType_t::mscclppFloat)
.value("float64", mscclppDataType_t::mscclppFloat64)
.value("double", mscclppDataType_t::mscclppDouble)
#if defined(__CUDA_BF16_TYPES_EXIST__)
.value("bfloat16", mscclppDataType_t::mscclppBfloat16)
#endif
;
nb::class_<mscclppUniqueId>(m, "MscclppUniqueId")
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
.def_static(
@@ -172,7 +145,7 @@ NB_MODULE(_py_mscclpp, m) {
comm._is_open = true;
return maybe(
mscclppCommInitRank(
&comm._handle, world_size, rank, address.c_str()),
&comm._handle, world_size, address.c_str(), rank),
comm,
"Failed to initialize comms: %s rank=%d world_size=%d",
address,

View File

@@ -4,13 +4,8 @@ __all__ = (
"MscclppUniqueId",
"MSCCLPP_UNIQUE_ID_BYTES",
"MscclppComm",
"dtype",
"reduce_op",
)
dtype = _py_mscclpp.dtype
reduce_op = _py_mscclpp.reduce_op
MscclppUniqueId = _py_mscclpp.MscclppUniqueId
MSCCLPP_UNIQUE_ID_BYTES = _py_mscclpp.MSCCLPP_UNIQUE_ID_BYTES

View File

@@ -3,63 +3,6 @@ import hamcrest
import mscclpp
class DTypeTest(unittest.TestCase):
def test(self) -> None:
for name, val in [
('int8', 0),
('char', 0),
('uint8', 1),
('int32', 2),
('int', 2),
('uint32', 3),
('int64', 4),
('uint64', 5),
('float16', 6),
('half', 6),
('float32', 7),
('float', 7),
('float64', 8),
('double', 8),
]:
try:
dtype = getattr(mscclpp.dtype, name)
hamcrest.assert_that(
mscclpp.dtype(val),
hamcrest.equal_to(dtype),
reason=(name, val),
)
hamcrest.assert_that(
int(mscclpp.dtype(val)),
hamcrest.equal_to(val),
reason=(name, val),
)
except Exception as e:
raise AssertionError((name, val)) from e
class ReduceOpTest(unittest.TestCase):
def test(self) -> None:
for name, val in [
('sum', 0),
('prod', 1),
('max', 2),
('min', 3),
('avg', 4),
]:
try:
dtype = getattr(mscclpp.reduce_op, name)
hamcrest.assert_that(
mscclpp.reduce_op(val),
hamcrest.equal_to(dtype),
reason=(name, val),
)
hamcrest.assert_that(
int(mscclpp.reduce_op(val)),
hamcrest.equal_to(val),
reason=(name, val),
)
except Exception as e:
raise AssertionError((name, val)) from e
class UniqueIdTest(unittest.TestCase):
def test_no_constructor(self) -> None: