mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
wip
This commit is contained in:
@@ -75,15 +75,17 @@ void register_algorithm(nb::module_& m) {
|
||||
[](Algorithm& self, std::shared_ptr<Communicator> comm, uintptr_t input, uintptr_t output,
|
||||
size_t inputSize, size_t outputSize, DataType dtype, ReduceOp op, uintptr_t stream,
|
||||
std::shared_ptr<Executor> executor, int nBlocks, int nThreadsPerBlock, bool symmetricMemory,
|
||||
std::unordered_map<std::string, uintptr_t> extras) {
|
||||
std::unordered_map<std::string, uintptr_t> extras, int32_t accumDtype) {
|
||||
return self.execute(comm, reinterpret_cast<const void*>(input), reinterpret_cast<void*>(output),
|
||||
inputSize, outputSize, dtype, op, reinterpret_cast<cudaStream_t>(stream), executor,
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras);
|
||||
nBlocks, nThreadsPerBlock, symmetricMemory, extras,
|
||||
static_cast<DataType>(accumDtype));
|
||||
},
|
||||
nb::arg("comm"), nb::arg("input"), nb::arg("output"), nb::arg("input_size"), nb::arg("output_size"),
|
||||
nb::arg("dtype"), nb::arg("op") = ReduceOp::NOP, nb::arg("stream") = 0, nb::arg("executor") = nullptr,
|
||||
nb::arg("n_blocks") = 0, nb::arg("n_threads_per_block") = 0, nb::arg("symmetric_memory") = false,
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>())
|
||||
nb::arg("extras") = std::unordered_map<std::string, uintptr_t>(),
|
||||
nb::arg("accum_dtype") = static_cast<int32_t>(DataType::AUTO))
|
||||
.def("reset", &Algorithm::reset);
|
||||
|
||||
nb::class_<Algorithm::Constraint>(algorithmClass, "Constraint")
|
||||
|
||||
Reference in New Issue
Block a user