mirror of
https://github.com/nomic-ai/kompute.git
synced 2026-05-12 09:25:39 +00:00
Added set and get functions
This commit is contained in:
@@ -39,10 +39,14 @@ PYBIND11_MODULE(kp, m) {
|
||||
return std::unique_ptr<kp::Tensor>(new kp::Tensor(data, tensorTypes));
|
||||
}), "Initialiser with list of data components and tensor GPU memory type.")
|
||||
.def("data", &kp::Tensor::data, DOC(kp, Tensor, data))
|
||||
.def("get", [](kp::Tensor &self, uint32_t index) -> float { return self.data()[index]; },
|
||||
"When only an index is necessary")
|
||||
.def("set", [](kp::Tensor &self, uint32_t index, float value) {
|
||||
self.data()[index] = value; })
|
||||
.def("set", &kp::Tensor::setData, "Overrides the data in the local Tensor memory.")
|
||||
.def("size", &kp::Tensor::size, "Retrieves the size of the Tensor data as per the local Tensor memory.")
|
||||
.def("tensor_type", &kp::Tensor::tensorType, "Retreves the memory type of the tensor.")
|
||||
.def("is_init", &kp::Tensor::isInit, "Checks whether the tensor GPU memory has been initialised.")
|
||||
.def("set_data", &kp::Tensor::setData, "Overrides the data in the local Tensor memory.")
|
||||
.def("map_data_from_host", &kp::Tensor::mapDataFromHostMemory, "Maps data into GPU memory from tensor local data.")
|
||||
.def("map_data_into_host", &kp::Tensor::mapDataIntoHostMemory, "Maps data from GPU memory into tensor local data.");
|
||||
|
||||
@@ -74,12 +78,12 @@ PYBIND11_MODULE(kp, m) {
|
||||
"Records an operation using a custom shader provided from a shader path")
|
||||
.def("record_algo_data", [](kp::Sequence &self,
|
||||
std::vector<std::shared_ptr<kp::Tensor>> tensors,
|
||||
py::bytes &bytes) {
|
||||
py::bytes &bytes) -> float {
|
||||
// Bytes have to be converted into std::vector
|
||||
py::buffer_info info(py::buffer(bytes).request());
|
||||
const char *data = reinterpret_cast<const char *>(info.ptr);
|
||||
size_t length = static_cast<size_t>(info.size);
|
||||
self.record<kp::OpAlgoBase>(
|
||||
return self.record<kp::OpAlgoBase>(
|
||||
tensors,
|
||||
std::vector<char>(data, data + length));
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user