diff --git a/python/src/main.cpp b/python/src/main.cpp index e3b7fb3..9573b80 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -13,7 +13,7 @@ PYBIND11_MODULE(komputepy, m) { .value("eStorage", kp::Tensor::TensorTypes::eStorage) .export_values(); - py::class_(m, "Tensor") + py::class_>(m, "Tensor") .def(py::init( [](const std::vector& data) { return std::unique_ptr(new kp::Tensor(data)); @@ -24,7 +24,75 @@ PYBIND11_MODULE(komputepy, m) { })) .def("data", &kp::Tensor::data); - py::class_(m, "OpBase"); + py::class_>(m, "Sequence") + .def("init", &kp::Sequence::init) + .def("begin", &kp::Sequence::begin) + .def("end", &kp::Sequence::end) + .def("eval", &kp::Sequence::eval) + .def("evalAsync", &kp::Sequence::evalAsync) + .def("evalAwait", &kp::Sequence::evalAwait) + .def("isRunning", &kp::Sequence::isRunning) + .def("isRecording", &kp::Sequence::isRecording) + .def("isInit", &kp::Sequence::isInit) + .def("recordOpTensorCreate", &kp::Sequence::record) + .def("recordOpTensorCopy", &kp::Sequence::record) + .def("recordOpTensorSyncDevice", &kp::Sequence::record) + .def("recordOpTensorSyncLocal", &kp::Sequence::record) + .def("recordOpAlgoMult", &kp::Sequence::record) + .def("recordOpAlgoBaseFile", &kp::Sequence::record) + .def("recordOpAlgoBaseData", &kp::Sequence::record>) + .def("recordOpAlgoLhsRhsOut", &kp::Sequence::record); + + py::class_(m, "Manager") + .def(py::init()) + .def(py::init( + [](uint32_t physicalDeviceIndex) { + return std::unique_ptr(new kp::Manager(physicalDeviceIndex)); + })) + .def(py::init( + [](uint32_t physicalDeviceIndex, const std::vector& familyQueueIndices) { + return std::unique_ptr(new kp::Manager(physicalDeviceIndex, familyQueueIndices)); + })) + .def("getOrCreateManagedSequence", &kp::Manager::getOrCreateManagedSequence) + .def("createManagedSequence", &kp::Manager::createManagedSequence, + py::arg("name"), py::arg("queueIndex") = 0) + .def("buildTensor", &kp::Manager::buildTensor, + py::arg("data"), py::arg("tensorType") = kp::Tensor::TensorTypes::eDevice) + .def("evalOpAsync", &kp::Manager::evalOpAsync) + .def("evalOpAsyncDefault", &kp::Manager::evalOpAsyncDefault) + .def("evalOpDefaultTensorCreate", &kp::Manager::evalOpDefault) + .def("evalOpDefaultTensorCopy", &kp::Manager::evalOpDefault) + .def("evalOpDefaultTensorSyncDevice", &kp::Manager::evalOpDefault) + .def("evalOpDefaultTensorSyncLocal", &kp::Manager::evalOpDefault) + .def("evalOpDefaultAlgoMult", &kp::Manager::evalOpDefault) + .def("evalOpDefaultAlgoBaseFile", &kp::Manager::evalOpDefault) + .def("evalOpDefaultAlgoBaseData", &kp::Manager::evalOpDefault>) + .def("evalOpDefaultAlgoLhsRhsOut", &kp::Manager::evalOpDefault) + .def("evalOpTensorCreate", &kp::Manager::evalOp) + .def("evalOpTensorCopy", &kp::Manager::evalOp) + .def("evalOpTensorSyncDevice", &kp::Manager::evalOp) + .def("evalOpTensorSyncLocal", &kp::Manager::evalOp) + .def("evalOpAlgoMult", &kp::Manager::evalOp) + .def("evalOpAlgoBaseFile", &kp::Manager::evalOp) + .def("evalOpAlgoBaseData", &kp::Manager::evalOp>) + .def("evalOpAlgoLhsRhsOut", &kp::Manager::evalOp) + .def("evalOpAsyncDefaultTensorCreate", &kp::Manager::evalOpAsyncDefault) + .def("evalOpAsyncDefaultTensorCopy", &kp::Manager::evalOpAsyncDefault) + .def("evalOpAsyncDefaultTensorSyncDevice", &kp::Manager::evalOpAsyncDefault) + .def("evalOpAsyncDefaultTensorSyncLocal", &kp::Manager::evalOpAsyncDefault) + .def("evalOpAsyncDefaultAlgoMult", &kp::Manager::evalOpAsyncDefault) + .def("evalOpAsyncDefaultAlgoBaseFile", &kp::Manager::evalOpAsyncDefault) + .def("evalOpAsyncDefaultAlgoBaseData", &kp::Manager::evalOpAsyncDefault>) + .def("evalOpAsyncDefaultAlgoLhsRhsOut", &kp::Manager::evalOpAsyncDefault) + .def("evalOpAsyncTensorCreate", &kp::Manager::evalOpAsync) + .def("evalOpAsyncTensorCopy", &kp::Manager::evalOpAsync) + .def("evalOpAsyncTensorSyncDevice", &kp::Manager::evalOpAsync) + .def("evalOpAsyncTensorSyncLocal", &kp::Manager::evalOpAsync) + .def("evalOpAsync", &kp::Manager::evalOpAsync) + .def("evalOpAsyncAlgoBaseFile", &kp::Manager::evalOpAsync) + .def("evalOpAsyncAlgoBase", &kp::Manager::evalOpAsync>) + .def("evalOpAsyncAlgoLhsRhsOut", &kp::Manager::evalOpAsync); + #ifdef VERSION_INFO m.attr("__version__") = VERSION_INFO;