From 2de6052e0bba4f890530a3a9f0f805420da40969 Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Wed, 8 Jan 2025 22:04:17 +0000 Subject: [PATCH 1/4] Add an implementation of vdot to the PytatoPyOpenCLArrayContext np namespace. --- arraycontext/impl/pytato/fake_numpy.py | 6 ++++ test/test_arraycontext.py | 43 ++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d7072855..ac5978ef 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -239,4 +239,10 @@ def amin(self, a, axis=None): def absolute(self, a): return self.abs(a) + def vdot(self, a: Array, b: Array, order_a: str = "C", order_b: str = "C"): + + flat_a = self.ravel(a, order_a) + flat_b = self.ravel(b, order_b) + + return rec_multimap_array_container(pt.vdot, flat_a, flat_b) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 14d24dd4..8a0b6016 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -271,9 +271,52 @@ def evaluate(np_, *args_): assert_close_to_numpy_in_containers(actx, evaluate, args) + +@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ + # float only + ("arctan2", 2, np.float64), + ("minimum", 2, np.float64), + ("maximum", 2, np.float64), + ("where", 3, np.float64), + ("min", 1, np.float64), + ("max", 1, np.float64), + ("any", 1, np.float64), + ("all", 1, np.float64), + ("arctan", 1, np.float64), + + # float + complex + ("sin", 1, np.float64), + ("sin", 1, np.complex128), + ("exp", 1, np.float64), + ("exp", 1, np.complex128), + ("conj", 1, np.float64), + ("conj", 1, np.complex128), + ("vdot", 2, np.float64), + ("vdot", 2, np.complex128), + ("abs", 1, np.float64), + ("abs", 1, np.complex128), + ("sum", 1, np.float64), + ("sum", 1, np.complex64), + ("isnan", 1, np.float64), + ]) +def test_array_context_np_workalike_with_scalars(actx_factory, sym_name, n_args, dtype): + actx = actx_factory() + if not hasattr(actx.np, sym_name): + pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: pytest.skip(f"'{sym_name}' not supported on scalars") + c_to_numpy_arc_functions = { + "atan": "arctan", + "atan2": "arctan2", + } + + def evaluate(np_, *args_): + func = getattr(np_, sym_name, + getattr(np_, c_to_numpy_arc_functions.get(sym_name, sym_name))) + + return func(*args_) + args = [randn(0, dtype)[()] for i in range(n_args)] assert_close_to_numpy(actx, evaluate, args) From 0d1100f54ad98a4298087e818dca02ae3ee842ee Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Wed, 8 Jan 2025 23:25:19 +0000 Subject: [PATCH 2/4] Remove the tests that are just skipped for scalars. --- test/test_arraycontext.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 8a0b6016..0e03920c 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -277,11 +277,6 @@ def evaluate(np_, *args_): ("arctan2", 2, np.float64), ("minimum", 2, np.float64), ("maximum", 2, np.float64), - ("where", 3, np.float64), - ("min", 1, np.float64), - ("max", 1, np.float64), - ("any", 1, np.float64), - ("all", 1, np.float64), ("arctan", 1, np.float64), # float + complex @@ -289,22 +284,14 @@ def evaluate(np_, *args_): ("sin", 1, np.complex128), ("exp", 1, np.float64), ("exp", 1, np.complex128), - ("conj", 1, np.float64), - ("conj", 1, np.complex128), - ("vdot", 2, np.float64), - ("vdot", 2, np.complex128), ("abs", 1, np.float64), ("abs", 1, np.complex128), - ("sum", 1, np.float64), - ("sum", 1, np.complex64), ("isnan", 1, np.float64), ]) def test_array_context_np_workalike_with_scalars(actx_factory, sym_name, n_args, dtype): actx = actx_factory() if not hasattr(actx.np, sym_name): pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") - if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: - pytest.skip(f"'{sym_name}' not supported on scalars") c_to_numpy_arc_functions = { "atan": "arctan", From 06fee0c1a74f2e88978aec666b1785db98b40052 Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Thu, 9 Jan 2025 20:20:08 +0000 Subject: [PATCH 3/4] Respond to comments. --- arraycontext/impl/pytato/fake_numpy.py | 7 ++--- test/test_arraycontext.py | 39 +++----------------------- 2 files changed, 6 insertions(+), 40 deletions(-) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index ac5978ef..21dc71ed 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -239,10 +239,7 @@ def amin(self, a, axis=None): def absolute(self, a): return self.abs(a) - def vdot(self, a: Array, b: Array, order_a: str = "C", order_b: str = "C"): + def vdot(self, a: Array, b: Array): - flat_a = self.ravel(a, order_a) - flat_b = self.ravel(b, order_b) - - return rec_multimap_array_container(pt.vdot, flat_a, flat_b) + return rec_multimap_array_container(pt.vdot, a, b) # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0e03920c..ad2cbb10 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -271,41 +271,10 @@ def evaluate(np_, *args_): assert_close_to_numpy_in_containers(actx, evaluate, args) - -@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ - # float only - ("arctan2", 2, np.float64), - ("minimum", 2, np.float64), - ("maximum", 2, np.float64), - ("arctan", 1, np.float64), - - # float + complex - ("sin", 1, np.float64), - ("sin", 1, np.complex128), - ("exp", 1, np.float64), - ("exp", 1, np.complex128), - ("abs", 1, np.float64), - ("abs", 1, np.complex128), - ("isnan", 1, np.float64), - ]) -def test_array_context_np_workalike_with_scalars(actx_factory, sym_name, n_args, dtype): - actx = actx_factory() - if not hasattr(actx.np, sym_name): - pytest.skip(f"'{sym_name}' not implemented on '{type(actx).__name__}'") - - c_to_numpy_arc_functions = { - "atan": "arctan", - "atan2": "arctan2", - } - - def evaluate(np_, *args_): - func = getattr(np_, sym_name, - getattr(np_, c_to_numpy_arc_functions.get(sym_name, sym_name))) - - return func(*args_) - - args = [randn(0, dtype)[()] for i in range(n_args)] - assert_close_to_numpy(actx, evaluate, args) + if sym_name not in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]: + # Scalar arguments are supported. + args = [randn(0, dtype)[()] for i in range(n_args)] + assert_close_to_numpy(actx, evaluate, args) @pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [ From 1aae2501c0130afa983e81842df252e9ce039590 Mon Sep 17 00:00:00 2001 From: Nick Koskelo Date: Thu, 9 Jan 2025 20:24:01 +0000 Subject: [PATCH 4/4] Ruff version needed to be updated locally. --- test/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 807d652d..3b74a42f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -27,7 +27,7 @@ THE SOFTWARE. """ import logging -from typing import Optional, cast +from typing import cast import numpy as np import pytest @@ -63,7 +63,7 @@ def test_dataclass_array_container() -> None: class ArrayContainerWithOptional: x: np.ndarray # Deliberately left as Optional to test compatibility. - y: Optional[np.ndarray] # noqa: UP007 + y: np.ndarray | None with pytest.raises(TypeError, match="Field 'y' union contains non-array"): # NOTE: cannot have wrapped annotations (here by `Optional`)