include left out lib; add enums

This commit is contained in:
Crutcher Dunnavant
2023-03-24 21:24:00 +00:00
parent e181cca064
commit 57b3c36975
3 changed files with 153 additions and 0 deletions

View File

@@ -99,11 +99,39 @@ static const std::string DOC_MscclppUniqueId =
static const std::string DOC_MscclppComm = "MSCCLPP Communications Handle";
NB_MODULE(_py_mscclpp, m) {
m.doc() = "Python bindings for MSCCLPP: which is not NCCL";
m.attr("MSCCLPP_UNIQUE_ID_BYTES") = MSCCLPP_UNIQUE_ID_BYTES;
nb::enum_<mscclppRedOp_t>(m, "reduce_op")
.value("sum", mscclppRedOp_t::mscclppSum)
.value("prod", mscclppRedOp_t::mscclppProd)
.value("max", mscclppRedOp_t::mscclppMax)
.value("min", mscclppRedOp_t::mscclppMin)
.value("avg", mscclppRedOp_t::mscclppAvg);
nb::enum_<mscclppDataType_t>(m, "dtype")
.value("int8", mscclppDataType_t::mscclppInt8)
.value("char", mscclppDataType_t::mscclppChar)
.value("uint8", mscclppDataType_t::mscclppUint8)
.value("int32", mscclppDataType_t::mscclppInt32)
.value("uint32", mscclppDataType_t::mscclppUint32)
.value("int", mscclppDataType_t::mscclppInt)
.value("int64", mscclppDataType_t::mscclppInt64)
.value("uint64", mscclppDataType_t::mscclppUint64)
.value("float16", mscclppDataType_t::mscclppFloat16)
.value("half", mscclppDataType_t::mscclppHalf)
.value("float32", mscclppDataType_t::mscclppFloat32)
.value("float", mscclppDataType_t::mscclppFloat)
.value("float64", mscclppDataType_t::mscclppFloat64)
.value("double", mscclppDataType_t::mscclppDouble)
#if defined(__CUDA_BF16_TYPES_EXIST__)
.value("bfloat16", mscclppDataType_t::mscclppBfloat16)
#endif
;
nb::class_<mscclppUniqueId>(m, "MscclppUniqueId")
.def_ro_static("__doc__", &DOC_MscclppUniqueId)
.def_static(