Skip to content

Commit eee34a3

Browse files
committed
Numba does not output numpy scalars
1 parent 9f18d56 commit eee34a3

File tree

3 files changed

+32
-12
lines changed

3 files changed

+32
-12
lines changed

tests/scalar/test_basic.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytensor.compile.mode import Mode
77
from pytensor.graph.fg import FunctionGraph
88
from pytensor.link.c.basic import DualLinker
9+
from pytensor.link.numba import NumbaLinker
910
from pytensor.scalar.basic import (
1011
EQ,
1112
ComplexError,
@@ -368,7 +369,9 @@ def _test_unary(unary_op, x_range):
368369
outi = fi(x_val)
369370
outf = ff(x_val)
370371

371-
assert outi.dtype == outf.dtype, "incorrect dtype"
372+
if not isinstance(ff.maker.linker, NumbaLinker):
373+
# Numba doesn't return numpy scalars
374+
assert outi.dtype == outf.dtype, "incorrect dtype"
372375
assert np.allclose(outi, outf), "insufficient precision"
373376

374377
@staticmethod
@@ -389,7 +392,9 @@ def _test_binary(binary_op, x_range, y_range):
389392
outi = fi(x_val, y_val)
390393
outf = ff(x_val, y_val)
391394

392-
assert outi.dtype == outf.dtype, "incorrect dtype"
395+
if not isinstance(ff.maker.linker, NumbaLinker):
396+
# Numba doesn't return numpy scalars
397+
assert outi.dtype == outf.dtype, "incorrect dtype"
393398
assert np.allclose(outi, outf), "insufficient precision"
394399

395400
def test_true_div(self):
@@ -414,7 +419,9 @@ def test_true_div(self):
414419
outi = fi(x_val, y_val)
415420
outf = ff(x_val, y_val)
416421

417-
assert outi.dtype == outf.dtype, "incorrect dtype"
422+
if not isinstance(ff.maker.linker, NumbaLinker):
423+
# Numba doesn't return numpy scalars
424+
assert outi.dtype == outf.dtype, "incorrect dtype"
418425
assert np.allclose(outi, outf), "insufficient precision"
419426

420427
def test_unary(self):

tests/tensor/test_basic.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pytensor.graph.basic import Apply, equal_computations
1919
from pytensor.graph.op import Op
2020
from pytensor.graph.replace import clone_replace
21+
from pytensor.link.numba import NumbaLinker
2122
from pytensor.raise_op import Assert
2223
from pytensor.scalar import autocast_float, autocast_float_as
2324
from pytensor.tensor import NoneConst, vectorize
@@ -2193,24 +2194,31 @@ def test_ScalarFromTensor(cast_policy):
21932194
assert ss.owner.op is scalar_from_tensor
21942195
assert ss.type.dtype == tc.type.dtype
21952196

2196-
v = eval_outputs([ss])
2197+
mode = get_default_mode()
2198+
v = eval_outputs([ss], mode=mode)
21972199

21982200
assert v == 56
2199-
assert v.shape == ()
2200-
2201-
if cast_policy == "custom":
2202-
assert isinstance(v, np.int8)
2203-
elif cast_policy == "numpy+floatX":
2204-
assert isinstance(v, np.int64)
2201+
if isinstance(mode.linker, NumbaLinker):
2202+
# Numba doesn't return numpy scalars
2203+
assert isinstance(v, int)
2204+
else:
2205+
assert v.shape == ()
2206+
if cast_policy == "custom":
2207+
assert isinstance(v, np.int8)
2208+
elif cast_policy == "numpy+floatX":
2209+
assert isinstance(v, np.int64)
22052210

22062211
pts = lscalar()
22072212
ss = scalar_from_tensor(pts)
22082213
ss.owner.op.grad([pts], [ss])
22092214
fff = function([pts], ss)
22102215
v = fff(np.asarray(5))
22112216
assert v == 5
2212-
assert isinstance(v, np.int64)
2213-
assert v.shape == ()
2217+
if isinstance(mode.linker, NumbaLinker):
2218+
assert isinstance(v, int)
2219+
else:
2220+
assert isinstance(v, np.int64)
2221+
assert v.shape == ()
22142222

22152223
with pytest.raises(TypeError):
22162224
scalar_from_tensor(vector())

tests/unittest_tools.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import warnings
44
from copy import copy, deepcopy
55
from functools import wraps
6+
from numbers import Number
67

78
import numpy as np
89
import pytest
@@ -259,6 +260,10 @@ def _compile_and_check(
259260
numeric_outputs = outputs_function(*numeric_inputs)
260261
numeric_shapes = shapes_function(*numeric_inputs)
261262
for out, shape in zip(numeric_outputs, numeric_shapes, strict=True):
263+
if not hasattr(out, "shape"):
264+
# Numba downcasts scalars to native Python types, which don't have shape
265+
assert isinstance(out, Number)
266+
out = np.asarray(out)
262267
assert np.all(out.shape == shape), (out.shape, shape)
263268

264269

0 commit comments

Comments
 (0)