Better NumPy support

This commit is contained in:
Wenzel Jakob
2015-07-22 00:59:01 +02:00
parent bd4a529319
commit 2ac80e77aa
4 changed files with 103 additions and 1 deletions

View File

@@ -206,6 +206,22 @@ public:
TYPE_CASTER(std::string, "str");
};
#ifdef HAVE_WCHAR_H
template <> class type_caster<std::wstring> {
public:
bool load(PyObject *src, bool) {
const wchar_t *ptr = PyUnicode_AsWideCharString(src, nullptr);
if (!ptr) { PyErr_Clear(); return false; }
value = std::wstring(ptr);
return true;
}
static PyObject *cast(const std::wstring &src, return_value_policy /* policy */, PyObject * /* parent */) {
return PyUnicode_FromWideChar(src.c_str(), src.length());
}
TYPE_CASTER(std::wstring, "wstr");
};
#endif
template <> class type_caster<char> {
public:
bool load(PyObject *src, bool) {
@@ -474,6 +490,7 @@ TYPE_CASTER_PYTYPE(list)
TYPE_CASTER_PYTYPE(slice)
TYPE_CASTER_PYTYPE(tuple)
TYPE_CASTER_PYTYPE(function)
TYPE_CASTER_PYTYPE(array)
#undef TYPE_CASTER
#undef TYPE_CASTER_PYTYPE

View File

@@ -132,7 +132,10 @@ private:
entry = backup;
}
std::string signatures;
int it = 0;
while (entry) { /* Create pydoc entry */
if (sibling.ptr())
signatures += std::to_string(++it) + ". ";
signatures += "Signature : " + std::string(entry->signature) + "\n";
if (!entry->doc.empty())
signatures += "\n" + std::string(entry->doc) + "\n";

View File

@@ -322,6 +322,88 @@ private:
Py_buffer *view = nullptr;
};
class array : public buffer {
protected:
struct API {
enum Entries {
API_PyArray_Type = 2,
API_PyArray_DescrFromType = 45,
API_PyArray_NewCopy = 85,
API_PyArray_NewFromDescr = 94
};
static API lookup() {
PyObject *numpy = PyImport_ImportModule("numpy.core.multiarray");
PyObject *capsule = numpy ? PyObject_GetAttrString(numpy, "_ARRAY_API") : nullptr;
void **api_ptr = (void **) (capsule ? PyCapsule_GetPointer(capsule, NULL) : nullptr);
Py_XDECREF(capsule);
Py_XDECREF(numpy);
if (api_ptr == nullptr)
throw std::runtime_error("Could not acquire pointer to NumPy API!");
API api;
api.PyArray_DescrFromType = (decltype(api.PyArray_DescrFromType)) api_ptr[API_PyArray_DescrFromType];
api.PyArray_NewFromDescr = (decltype(api.PyArray_NewFromDescr)) api_ptr[API_PyArray_NewFromDescr];
api.PyArray_NewCopy = (decltype(api.PyArray_NewCopy)) api_ptr[API_PyArray_NewCopy];
api.PyArray_Type = (decltype(api.PyArray_Type)) api_ptr[API_PyArray_Type];
return api;
}
bool PyArray_Check(PyObject *obj) const {
return (bool) PyObject_TypeCheck(obj, PyArray_Type);
}
PyObject *(*PyArray_DescrFromType)(int);
PyObject *(*PyArray_NewFromDescr)
(PyTypeObject *, PyObject *, int, Py_intptr_t *,
Py_intptr_t *, void *, int, PyObject *);
PyObject *(*PyArray_NewCopy)(PyObject *, int);
PyTypeObject *PyArray_Type;
};
public:
PYTHON_OBJECT_DEFAULT(array, buffer, lookup_api().PyArray_Check)
template <typename Type> array(size_t size, const Type *ptr) {
API& api = lookup_api();
PyObject *descr = api.PyArray_DescrFromType(
(int) format_descriptor<Type>::value()[0]);
if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format!");
Py_intptr_t shape = (Py_intptr_t) size;
PyObject *tmp = api.PyArray_NewFromDescr(
api.PyArray_Type, descr, 1, &shape, nullptr, (void *) ptr, 0, nullptr);
if (tmp == nullptr)
throw std::runtime_error("NumPy: unable to create array!");
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
Py_DECREF(tmp);
if (m_ptr == nullptr)
throw std::runtime_error("NumPy: unable to copy array!");
}
array(const buffer_info &info) {
API& api = lookup_api();
if (info.format.size() != 1)
throw std::runtime_error("Unsupported buffer format!");
PyObject *descr = api.PyArray_DescrFromType(info.format[0]);
if (descr == nullptr)
throw std::runtime_error("NumPy: unsupported buffer format '" + info.format + "'!");
PyObject *tmp = api.PyArray_NewFromDescr(
api.PyArray_Type, descr, info.ndim, (Py_intptr_t *) &info.shape[0],
(Py_intptr_t *) &info.strides[0], info.ptr, 0, nullptr);
if (tmp == nullptr)
throw std::runtime_error("NumPy: unable to create array!");
m_ptr = api.PyArray_NewCopy(tmp, -1 /* any order */);
Py_DECREF(tmp);
if (m_ptr == nullptr)
throw std::runtime_error("NumPy: unable to copy array!");
}
protected:
static API &lookup_api() {
static API api = API::lookup();
return api;
}
};
NAMESPACE_BEGIN(detail)
inline internals &get_internals() {
static internals *internals_ptr = nullptr;