diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d7072855..21dc71ed 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -239,4 +239,7 @@ def amin(self, a, axis=None): def absolute(self, a): return self.abs(a) + def vdot(self, a: Array, b: Array): + + return rec_multimap_array_container(pt.vdot, a, b) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 14d24dd4..ad2cbb10 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -271,11 +271,10 @@ def evaluate(np_, *args_): assert_close_to_numpy_in_containers(actx, evaluate, args) - if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: - pytest.skip(f"'{sym_name}' not supported on scalars") - - args = [randn(0, dtype)[()] for i in range(n_args)] - assert_close_to_numpy(actx, evaluate, args) + if sym_name not in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: + # Scalar arguments are supported. + args = [randn(0, dtype)[()] for i in range(n_args)] + assert_close_to_numpy(actx, evaluate, args) @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ diff --git a/test/test_utils.py b/test/test_utils.py index 807d652d..3b74a42f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -27,7 +27,7 @@ THE SOFTWARE. """ import logging -from typing import Optional, cast +from typing import cast import numpy as np import pytest @@ -63,7 +63,7 @@ def test_dataclass_array_container() -> None: class ArrayContainerWithOptional: x: np.ndarray # Deliberately left as Optional to test compatibility. - y: Optional[np.ndarray] # noqa: UP007 + y: np.ndarray | None with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`)