Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags (#2484)

* Stop py::array_t arguments from accepting arrays that do not match the C- or F-contiguity flags

* Add trivially-contiguous arrays to the tests
This commit is contained in:
Yannick Jadoul
2020-09-15 14:50:51 +02:00
committed by GitHub
parent f12ec00d70
commit 9df13835c8
3 changed files with 86 additions and 1 deletions

View File

@@ -435,6 +435,52 @@ def test_index_using_ellipsis():
assert a.shape == (6,)
@pytest.mark.parametrize("forcecast", [False, True])
@pytest.mark.parametrize("contiguity", [None, 'C', 'F'])
@pytest.mark.parametrize("noconvert", [False, True])
@pytest.mark.filterwarnings(
"ignore:Casting complex values to real discards the imaginary part:numpy.ComplexWarning"
)
def test_argument_conversions(forcecast, contiguity, noconvert):
function_name = "accept_double"
if contiguity == 'C':
function_name += "_c_style"
elif contiguity == 'F':
function_name += "_f_style"
if forcecast:
function_name += "_forcecast"
if noconvert:
function_name += "_noconvert"
function = getattr(m, function_name)
for dtype in [np.dtype('float32'), np.dtype('float64'), np.dtype('complex128')]:
for order in ['C', 'F']:
for shape in [(2, 2), (1, 3, 1, 1), (1, 1, 1), (0,)]:
if not noconvert:
# If noconvert is not passed, only complex128 needs to be truncated and
# "cannot be safely obtained". So without `forcecast`, the argument shouldn't
# be accepted.
should_raise = dtype.name == 'complex128' and not forcecast
else:
# If noconvert is passed, only float64 and the matching order is accepted.
# If at most one dimension has a size greater than 1, the array is also
# trivially contiguous.
trivially_contiguous = sum(1 for d in shape if d > 1) <= 1
should_raise = (
dtype.name != 'float64' or
(contiguity is not None and
contiguity != order and
not trivially_contiguous)
)
array = np.zeros(shape, dtype=dtype, order=order)
if not should_raise:
function(array)
else:
with pytest.raises(TypeError, match="incompatible function arguments"):
function(array)
@pytest.mark.xfail("env.PYPY")
def test_dtype_refcount_leak():
from sys import getrefcount