mirror of
https://github.com/nomic-ai/kompute.git
synced 2026-05-11 08:59:59 +00:00
faster set_data()
This commit is contained in:
@@ -10,6 +10,8 @@ namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(kp, m) {
|
||||
|
||||
py::module_ np = py::module_::import("numpy");
|
||||
|
||||
#if KOMPUTE_ENABLE_SPDLOG
|
||||
spdlog::set_level(
|
||||
static_cast<spdlog::level::level_enum>(SPDLOG_ACTIVE_LEVEL));
|
||||
@@ -40,25 +42,19 @@ PYBIND11_MODULE(kp, m) {
|
||||
return std::unique_ptr<kp::Tensor>(new kp::Tensor(data, tensorTypes));
|
||||
}), "Initialiser with list of data components and tensor GPU memory type.")
|
||||
.def("data", &kp::Tensor::data, DOC(kp, Tensor, data))
|
||||
.def("numpy", [](kp::Tensor& self){
|
||||
ssize_t ndim = 1;
|
||||
std::vector<ssize_t> shape = { self.size() };
|
||||
std::vector<ssize_t> strides = { sizeof(float) };
|
||||
|
||||
return py::array(py::buffer_info(
|
||||
self.data().data(),
|
||||
sizeof(float),
|
||||
py::format_descriptor<float>::format(),
|
||||
ndim,
|
||||
shape,
|
||||
strides
|
||||
));
|
||||
.def("numpy", [](kp::Tensor& self) {
|
||||
return py::array(self.data().size(), self.data().data());
|
||||
}, "Returns stored data as a new numpy array.")
|
||||
.def("__getitem__", [](kp::Tensor &self, size_t index) -> float { return self.data()[index]; },
|
||||
"When only an index is necessary")
|
||||
.def("__setitem__", [](kp::Tensor &self, size_t index, float value) {
|
||||
self.data()[index] = value; })
|
||||
.def("set_data", &kp::Tensor::setData, "Overrides the data in the local Tensor memory.")
|
||||
.def("set_data", [np](kp::Tensor &self, const py::array_t<float> data){
|
||||
const py::array_t<float> flatdata = np.attr("ravel")(data);
|
||||
const py::buffer_info info = flatdata.request();
|
||||
const float* ptr = (float*) info.ptr;
|
||||
self.setData(std::vector<float>(ptr, ptr+flatdata.size()));
|
||||
}, "Overrides the data in the local Tensor memory.")
|
||||
.def("__iter__", [](kp::Tensor &self) {
|
||||
return py::make_iterator(self.data().begin(), self.data().end());
|
||||
}, py::keep_alive<0, 1>(), // Required to keep alive iterator while exists
|
||||
|
||||
Reference in New Issue
Block a user