faster set_data()

This commit is contained in:
alexander-g
2021-01-10 15:29:18 +01:00
parent 163335111f
commit 893fd4fc7c

View File

@@ -10,6 +10,8 @@ namespace py = pybind11;
PYBIND11_MODULE(kp, m) {
py::module_ np = py::module_::import("numpy");
#if KOMPUTE_ENABLE_SPDLOG
spdlog::set_level(
static_cast<spdlog::level::level_enum>(SPDLOG_ACTIVE_LEVEL));
@@ -40,25 +42,19 @@ PYBIND11_MODULE(kp, m) {
return std::unique_ptr<kp::Tensor>(new kp::Tensor(data, tensorTypes));
}), "Initialiser with list of data components and tensor GPU memory type.")
.def("data", &kp::Tensor::data, DOC(kp, Tensor, data))
.def("numpy", [](kp::Tensor& self){
ssize_t ndim = 1;
std::vector<ssize_t> shape = { self.size() };
std::vector<ssize_t> strides = { sizeof(float) };
return py::array(py::buffer_info(
self.data().data(),
sizeof(float),
py::format_descriptor<float>::format(),
ndim,
shape,
strides
));
.def("numpy", [](kp::Tensor& self) {
return py::array(self.data().size(), self.data().data());
}, "Returns stored data as a new numpy array.")
.def("__getitem__", [](kp::Tensor &self, size_t index) -> float { return self.data()[index]; },
"When only an index is necessary")
.def("__setitem__", [](kp::Tensor &self, size_t index, float value) {
self.data()[index] = value; })
.def("set_data", &kp::Tensor::setData, "Overrides the data in the local Tensor memory.")
.def("set_data", [np](kp::Tensor &self, const py::array_t<float> data){
const py::array_t<float> flatdata = np.attr("ravel")(data);
const py::buffer_info info = flatdata.request();
const float* ptr = (float*) info.ptr;
self.setData(std::vector<float>(ptr, ptr+flatdata.size()));
}, "Overrides the data in the local Tensor memory.")
.def("__iter__", [](kp::Tensor &self) {
return py::make_iterator(self.data().begin(), self.data().end());
}, py::keep_alive<0, 1>(), // Required to keep alive iterator while exists