diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 21bcd714..3dc87adb 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -79,6 +79,7 @@ from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext +from .impl.numpy import NumpyArrayContext from .loopy import make_loopy_program from .pytest import ( PytestArrayContextFactory, diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 48d8a4af..feab76a1 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -85,6 +85,35 @@ def __getattr__(self, name: str): def zeros(self, shape: int | tuple[int, ...], dtype: DTypeLike) -> Array: return cast("Array", cast("object", jnp.zeros(shape=shape, dtype=dtype))) + def empty_like(self, ary): + from warnings import warn + warn(f"{type(self._array_context).__name__}.np.empty_like is " + "deprecated and will stop working in 2023. Prefer actx.np.zeros_like " + "instead.", + DeprecationWarning, stacklevel=2) + + def _empty_like(array): + return self._array_context.empty(array.shape, array.dtype) + + return self._array_context._rec_map_container(_empty_like, ary) + + def zeros_like(self, ary): + def _zeros_like(array): + return self._array_context.zeros(array.shape, array.dtype) + + return self._array_context._rec_map_container( + _zeros_like, ary, default_scalar=0) + + def ones_like(self, ary): + return self.full_like(ary, 1) + + def full_like(self, ary, fill_value): + def _full_like(subary): + return jnp.full_like(subary, fill_value) + + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) + @override def _full_like_array(self, ary: Array, diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index c3db39da..988d6a8f 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -375,10 +375,33 @@ def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationU "to create this kernel?") all_inames = default_entrypoint.all_inames() - + # FIXME: This could be much smarter. inner_iname = None - if "i0" in all_inames: + # import with underscore to avoid DeprecationWarning + # from arraycontext.metadata import _FirstAxisIsElementsTag + from meshmode.transform_metadata import FirstAxisIsElementsTag + + if (len(default_entrypoint.instructions) == 1 + and isinstance(default_entrypoint.instructions[0], lp.Assignment) + and any(isinstance(tag, FirstAxisIsElementsTag) + # FIXME: Firedrake branch lacks kernel tags + for tag in getattr(default_entrypoint, "tags", ()))): + stmt, = default_entrypoint.instructions + + out_inames = [v.name for v in stmt.assignee.index_tuple] + assert out_inames + outer_iname = out_inames[0] + if len(out_inames) >= 2: + inner_iname = out_inames[1] + + elif "iel" in all_inames: + outer_iname = "iel" + + if "idof" in all_inames: + inner_iname = "idof" + + elif "i0" in all_inames: outer_iname = "i0" if "i1" in all_inames: diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index d770c7f7..9374e0ad 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -673,6 +673,7 @@ def _to_frozen( pt_prg = pt.generate_loopy(transformed_dag, options=opts, + cl_device=self.queue.device, function_name=function_name, target=self.get_target() ).bind_to_context(self.context) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 323eb791..51b7b025 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -543,8 +543,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): return pytato_program, name_in_program_to_tags, name_in_program_to_axes -def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, - fn_name=""): +def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): input_kwargs_for_loopy = {} for arg_id, arg in arg_id_to_arg.items(): @@ -565,13 +564,16 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): - # got an array expression => abort - raise ValueError( - f"Argument '{arg_id}' to the '{fn_name}' compiled function is a" - " pytato array expression. Evaluating it just-in-time" - " potentially causes a significant overhead on each call to the" - " function and is therefore unsupported. " - ) + # got an array expression => evaluate it + from warnings import warn + warn(f"Argument array '{arg_id}' to a compiled function is " + "unevaluated. Evaluating just-in-time, at " + "considerable expense. This is deprecated and will stop " + "working in 2023. To avoid this warning, force evaluation " + "of all arguments via freeze/thaw.", + DeprecationWarning, stacklevel=4) + + arg = actx.freeze(arg) else: raise NotImplementedError(type(arg)) @@ -579,6 +581,15 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, return input_kwargs_for_loopy + +def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): + from warnings import warn + warn("_args_to_cl_buffer has been renamed to" + " _args_to_device_buffers. This will be" + " an error in 2023.", DeprecationWarning, stacklevel=2) + return _args_to_device_buffers(actx, input_id_to_name_in_program, + arg_id_to_arg) + # }}} @@ -634,7 +645,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.loopy.BoundPyOpenCLExecutable + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -645,10 +656,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - fn_name = self.pytato_program.program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) if self.actx.profile_kernels: import pyopencl as cl @@ -681,7 +690,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.loopy.BoundPyOpenCLExecutable + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] @@ -691,10 +700,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - fn_name = self.pytato_program.program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) if self.actx.profile_kernels: import pyopencl as cl @@ -734,7 +741,7 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.python.BoundJAXPythonProgram + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -742,10 +749,8 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: - fn_name = self.pytato_program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) out_dict = self.pytato_program(**input_kwargs_for_loopy) @@ -767,17 +772,15 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.python.BoundJAXPythonProgram + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: - fn_name = self.pytato_program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) _evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 75274580..db84d1fa 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -107,6 +107,28 @@ def __getattr__(self, name: str): def zeros(self, shape: int | tuple[int, ...], dtype: DTypeLike) -> Array: return pt.zeros(shape, dtype) + def zeros_like(self, ary): + def _zeros_like(array): + # return self._array_context.np.zeros( + # array.shape, array.dtype).copy(axes=array.axes, + # tags=array.tags) + return 0*array + + return self._array_context._rec_map_container( + _zeros_like, ary, default_scalar=0) + + def ones_like(self, ary): + return self.full_like(ary, 1) + + def full_like(self, ary, fill_value): + def _full_like(subary): + # return pt.full(subary.shape, fill_value, subary.dtype).copy( + # axes=subary.axes, tags=subary.tags) + return fill_value * (0*subary + 1) + + return self._array_context._rec_map_container( + _full_like, ary, default_scalar=fill_value) + @override def _full_like_array(self, ary: Array, diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 75e19bc6..8b9381d1 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -45,9 +45,9 @@ from arraycontext.context import ArrayContext - # {{{ array context factories + class PytestArrayContextFactory: @classmethod def is_available(cls) -> bool: @@ -238,6 +238,27 @@ def __call__(self): def __str__(self): return "" +# {{{ _PytestArrayContextFactory + + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + # {{{ _PytestArrayContextFactory diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 2dc47cc5..d3e84218 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -828,8 +828,9 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): with pytest.raises(TypeError): ary_of_dofs + dc_of_dofs - with pytest.raises(TypeError): - dc_of_dofs + ary_of_dofs + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + dc_of_dofs + ary_of_dofs with pytest.raises(TypeError): ary_dof + dc_of_dofs @@ -1031,7 +1032,11 @@ def test_flatten_with_leaf_class(actx_factory: ArrayContextFactory): # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory: ArrayContextFactory): + from arraycontext import NumpyArrayContext actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("Irrelevant tests for NumpyArrayContext") + rng = np.random.default_rng() nelements = 42 @@ -1164,6 +1169,7 @@ def test_actx_compile_kwargs(actx_factory: ArrayContextFactory): def test_actx_compile_with_tuple_output_keys(actx_factory: ArrayContextFactory): # arraycontext.git<=3c9aee68 would fail due to a bug in output # key stringification logic. + from arraycontext import from_numpy, to_numpy actx = actx_factory() rng = np.random.default_rng() @@ -1177,11 +1183,11 @@ def my_rhs(scale, vel): v_x = rng.uniform(size=10) v_y = rng.uniform(size=10) - vel = actx.from_numpy(Velocity2D(v_x, v_y, actx)) + vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) scaled_speed = compiled_rhs(3.14, vel=vel) - result = actx.to_numpy(scaled_speed)[0, 0] + result = to_numpy(scaled_speed, actx)[0, 0] np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) @@ -1432,8 +1438,6 @@ class ArrayContainerWithNumpy: u: np.ndarray v: DOFArray - __array_ufunc__ = None - def test_array_container_with_numpy(actx_factory: ArrayContextFactory): actx = actx_factory() @@ -1553,16 +1557,14 @@ def test_compile_anonymous_function(actx_factory: ArrayContextFactory): # See https://github.com/inducer/grudge/issues/287 actx = actx_factory() - - ones = actx.thaw(actx.freeze( - actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1 - )) - f = actx.compile(lambda x: 2*x+40) - np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) - + np.testing.assert_allclose( + actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), + 42) f = actx.compile(partial(lambda x: 2*x+40)) - np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) + np.testing.assert_allclose( + actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), + 42) @pytest.mark.parametrize( diff --git a/test/testlib.py b/test/testlib.py index 81e7b1b4..508f9228 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -163,7 +163,7 @@ def array_context(self): @with_container_arithmetic( bcasts_across_obj_array=False, - container_types_bcast_across=(DOFArray, np.ndarray), + bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, _cls_has_array_context_attr=True, @@ -216,7 +216,6 @@ class Velocity2D: __array_ufunc__: ClassVar[None] = None - @with_array_context.register(Velocity2D) # https://github.com/python/mypy/issues/13040 def _with_actx_velocity_2d(ary, actx): # type: ignore[misc]