Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
bc87ce8
Apply CEESD changes
MTCam Jul 18, 2024
c14b18c
fix full_like, zeros_like
matthiasdiener Jul 18, 2024
a83410f
invert container arithmetic checks
matthiasdiener Jul 18, 2024
bfd22a5
restore loop inference fallback
matthiasdiener Jul 18, 2024
b070816
Merge branch 'main' into production-pilot-update
MTCam Aug 29, 2024
3d715a8
Add zeros to NumpyFakeNumpyContext
MTCam Aug 29, 2024
1c6e6a4
Merge branch 'main' into production-pilot
matthiasdiener Sep 6, 2024
7d3e557
Merge branch 'main' into production-pilot
MTCam Oct 13, 2024
7f6840e
Merge branch 'main' into production-pilot
MTCam Oct 21, 2024
aeeb47d
Merge branch 'main' into production-pilot
MTCam Nov 15, 2024
624143a
Merge branch 'main' into production-pilot
MTCam Nov 19, 2024
b7dbab8
Merge branch 'main' into production-pilot
MTCam Nov 28, 2024
ee2f1d1
Merge branch 'main' into production-pilot
MTCam Dec 2, 2024
f66d33f
Merge branch 'main' into production-pilot
MTCam Dec 17, 2024
771e4ba
Merge branch 'main' into production-pilot
MTCam Dec 18, 2024
2894965
Merge branch 'main' into production-pilot
MTCam Jan 8, 2025
1fd45e9
Merge branch 'main' into production-pilot
MTCam Jan 10, 2025
efd5a8d
Merge branch 'main' into production-pilot
MTCam Jan 12, 2025
1cdd41a
Merge branch 'main' into production-pilot
MTCam Jan 27, 2025
044f191
Merge branch 'main' into production-pilot-merge-main
MTCam Feb 3, 2025
d5814cc
Merge branch 'main' into production-pilot
MTCam Mar 28, 2025
0bdcdf6
Merge branch 'main' into production-pilot
MTCam Mar 28, 2025
c9c3489
Revert "upgrade 'unevaluated array as argument' warning to error (#305)"
MTCam Apr 1, 2025
af74356
Merge branch 'main' into production-pilot
MTCam Apr 9, 2025
2fd7174
Revert to 0*array
MTCam Apr 9, 2025
633cb03
Merge branch 'main' into production-pilot
MTCam Apr 28, 2025
4aa9ee4
Merge branch 'main' into production-pilot
MTCam May 20, 2025
52bd8bc
Merge branch 'main' into production-pilot
MTCam Jun 2, 2025
1911d97
Merge branch 'main' into production-pilot
MTCam Jun 23, 2025
c05282d
Merge branch 'main' into production-pilot
MTCam Jul 1, 2025
b2e8ec3
Merge branch 'main' into production-pilot
MTCam Jul 8, 2025
48cc42e
Merge branch 'main' into production-pilot
MTCam Jul 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .impl.numpy import NumpyArrayContext
from .loopy import make_loopy_program
from .pytest import (
PytestArrayContextFactory,
Expand Down
29 changes: 29 additions & 0 deletions arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,35 @@ def __getattr__(self, name: str):
def zeros(self, shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
return cast("Array", cast("object", jnp.zeros(shape=shape, dtype=dtype)))

def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
"deprecated and will stop working in 2023. Prefer actx.np.zeros_like "
"instead.",
DeprecationWarning, stacklevel=2)

def _empty_like(array):
return self._array_context.empty(array.shape, array.dtype)

return self._array_context._rec_map_container(_empty_like, ary)

def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(array.shape, array.dtype)

return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)

def ones_like(self, ary):
return self.full_like(ary, 1)

def full_like(self, ary, fill_value):
def _full_like(subary):
return jnp.full_like(subary, fill_value)

return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)

@override
def _full_like_array(self,
ary: Array,
Expand Down
27 changes: 25 additions & 2 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,33 @@ def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationU
"to create this kernel?")

all_inames = default_entrypoint.all_inames()

# FIXME: This could be much smarter.
inner_iname = None

if "i0" in all_inames:
# import with underscore to avoid DeprecationWarning
# from arraycontext.metadata import _FirstAxisIsElementsTag
from meshmode.transform_metadata import FirstAxisIsElementsTag

if (len(default_entrypoint.instructions) == 1
and isinstance(default_entrypoint.instructions[0], lp.Assignment)
and any(isinstance(tag, FirstAxisIsElementsTag)
# FIXME: Firedrake branch lacks kernel tags
for tag in getattr(default_entrypoint, "tags", ()))):
stmt, = default_entrypoint.instructions

out_inames = [v.name for v in stmt.assignee.index_tuple]
assert out_inames
outer_iname = out_inames[0]
if len(out_inames) >= 2:
inner_iname = out_inames[1]

elif "iel" in all_inames:
outer_iname = "iel"

if "idof" in all_inames:
inner_iname = "idof"

elif "i0" in all_inames:
outer_iname = "i0"

if "i1" in all_inames:
Expand Down
1 change: 1 addition & 0 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,7 @@ def _to_frozen(

pt_prg = pt.generate_loopy(transformed_dag,
options=opts,
cl_device=self.queue.device,
function_name=function_name,
target=self.get_target()
).bind_to_context(self.context)
Expand Down
53 changes: 28 additions & 25 deletions arraycontext/impl/pytato/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
return pytato_program, name_in_program_to_tags, name_in_program_to_axes


def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg,
fn_name="<unknown>"):
def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
input_kwargs_for_loopy = {}

for arg_id, arg in arg_id_to_arg.items():
Expand All @@ -565,20 +564,32 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg,
# got a frozen array => do nothing
pass
elif isinstance(arg, pt.Array):
# got an array expression => abort
raise ValueError(
f"Argument '{arg_id}' to the '{fn_name}' compiled function is a"
" pytato array expression. Evaluating it just-in-time"
" potentially causes a significant overhead on each call to the"
" function and is therefore unsupported. "
)
# got an array expression => evaluate it
from warnings import warn
warn(f"Argument array '{arg_id}' to a compiled function is "
"unevaluated. Evaluating just-in-time, at "
"considerable expense. This is deprecated and will stop "
"working in 2023. To avoid this warning, force evaluation "
"of all arguments via freeze/thaw.",
DeprecationWarning, stacklevel=4)

arg = actx.freeze(arg)
else:
raise NotImplementedError(type(arg))

input_kwargs_for_loopy[input_id_to_name_in_program[arg_id]] = arg

return input_kwargs_for_loopy


def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
from warnings import warn
warn("_args_to_cl_buffer has been renamed to"
" _args_to_device_buffers. This will be"
" an error in 2023.", DeprecationWarning, stacklevel=2)
return _args_to_device_buffers(actx, input_id_to_name_in_program,
arg_id_to_arg)

# }}}


Expand Down Expand Up @@ -634,7 +645,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
type of the callable.
"""
actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
Expand All @@ -645,10 +656,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array

fn_name = self.pytato_program.program.entrypoint

input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

if self.actx.profile_kernels:
import pyopencl as cl
Expand Down Expand Up @@ -681,7 +690,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
Name of the output array in the program.
"""
actx: PytatoPyOpenCLArrayContext
pytato_program: pt.target.loopy.BoundPyOpenCLExecutable
pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_tags: frozenset[Tag]
output_axes: tuple[pt.Axis, ...]
Expand All @@ -691,10 +700,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
from .utils import get_cl_axes_from_pt_axes
from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array

fn_name = self.pytato_program.program.entrypoint

input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

if self.actx.profile_kernels:
import pyopencl as cl
Expand Down Expand Up @@ -734,18 +741,16 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
type of the callable.
"""
actx: PytatoJAXArrayContext
pytato_program: pt.target.python.BoundJAXPythonProgram
pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
name_in_program_to_tags: Mapping[str, frozenset[Tag]]
name_in_program_to_axes: Mapping[str, tuple[pt.Axis, ...]]
output_template: ArrayContainer

def __call__(self, arg_id_to_arg) -> ArrayContainer:
fn_name = self.pytato_program.entrypoint

input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

out_dict = self.pytato_program(**input_kwargs_for_loopy)

Expand All @@ -767,17 +772,15 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
Name of the output array in the program.
"""
actx: PytatoJAXArrayContext
pytato_program: pt.target.python.BoundJAXPythonProgram
pytato_program: pt.target.BoundProgram
input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str]
output_tags: frozenset[Tag]
output_axes: tuple[pt.Axis, ...]
output_name: str

def __call__(self, arg_id_to_arg) -> ArrayContainer:
fn_name = self.pytato_program.entrypoint

input_kwargs_for_loopy = _args_to_device_buffers(
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)

_evt, out_dict = self.pytato_program(**input_kwargs_for_loopy)

Expand Down
22 changes: 22 additions & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@ def __getattr__(self, name: str):
def zeros(self, shape: int | tuple[int, ...], dtype: DTypeLike) -> Array:
return pt.zeros(shape, dtype)

def zeros_like(self, ary):
def _zeros_like(array):
# return self._array_context.np.zeros(
# array.shape, array.dtype).copy(axes=array.axes,
# tags=array.tags)
return 0*array

return self._array_context._rec_map_container(
_zeros_like, ary, default_scalar=0)

def ones_like(self, ary):
return self.full_like(ary, 1)

def full_like(self, ary, fill_value):
def _full_like(subary):
# return pt.full(subary.shape, fill_value, subary.dtype).copy(
# axes=subary.axes, tags=subary.tags)
return fill_value * (0*subary + 1)

return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)

@override
def _full_like_array(self,
ary: Array,
Expand Down
23 changes: 22 additions & 1 deletion arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@

from arraycontext.context import ArrayContext


# {{{ array context factories


class PytestArrayContextFactory:
@classmethod
def is_available(cls) -> bool:
Expand Down Expand Up @@ -238,6 +238,27 @@ def __call__(self):
def __str__(self):
return "<PytatoJAXArrayContext>"

# {{{ _PytestArrayContextFactory


class _NumpyArrayContextForTests(NumpyArrayContext):
def transform_loopy_program(self, t_unit):
return t_unit


class _PytestNumpyArrayContextFactory(PytestArrayContextFactory):
def __init__(self, *args, **kwargs):
super().__init__()

def __call__(self):
return _NumpyArrayContextForTests()

def __str__(self):
return "<NumpyArrayContext>"

# }}}



# {{{ _PytestArrayContextFactory

Expand Down
30 changes: 16 additions & 14 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,8 +828,9 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14):
with pytest.raises(TypeError):
ary_of_dofs + dc_of_dofs

with pytest.raises(TypeError):
dc_of_dofs + ary_of_dofs
if not isinstance(actx, NumpyArrayContext):
with pytest.raises(TypeError):
dc_of_dofs + ary_of_dofs

with pytest.raises(TypeError):
ary_dof + dc_of_dofs
Expand Down Expand Up @@ -1031,7 +1032,11 @@ def test_flatten_with_leaf_class(actx_factory: ArrayContextFactory):
# {{{ test from_numpy and to_numpy

def test_numpy_conversion(actx_factory: ArrayContextFactory):
from arraycontext import NumpyArrayContext
actx = actx_factory()
if isinstance(actx, NumpyArrayContext):
pytest.skip("Irrelevant tests for NumpyArrayContext")

rng = np.random.default_rng()

nelements = 42
Expand Down Expand Up @@ -1164,6 +1169,7 @@ def test_actx_compile_kwargs(actx_factory: ArrayContextFactory):
def test_actx_compile_with_tuple_output_keys(actx_factory: ArrayContextFactory):
# arraycontext.git<=3c9aee68 would fail due to a bug in output
# key stringification logic.
from arraycontext import from_numpy, to_numpy
actx = actx_factory()
rng = np.random.default_rng()

Expand All @@ -1177,11 +1183,11 @@ def my_rhs(scale, vel):
v_x = rng.uniform(size=10)
v_y = rng.uniform(size=10)

vel = actx.from_numpy(Velocity2D(v_x, v_y, actx))
vel = from_numpy(Velocity2D(v_x, v_y, actx), actx)

scaled_speed = compiled_rhs(3.14, vel=vel)

result = actx.to_numpy(scaled_speed)[0, 0]
result = to_numpy(scaled_speed, actx)[0, 0]
np.testing.assert_allclose(result.u, -3.14*v_y)
np.testing.assert_allclose(result.v, 3.14*v_x)

Expand Down Expand Up @@ -1432,8 +1438,6 @@ class ArrayContainerWithNumpy:
u: np.ndarray
v: DOFArray

__array_ufunc__ = None


def test_array_container_with_numpy(actx_factory: ArrayContextFactory):
actx = actx_factory()
Expand Down Expand Up @@ -1553,16 +1557,14 @@ def test_compile_anonymous_function(actx_factory: ArrayContextFactory):

# See https://github.com/inducer/grudge/issues/287
actx = actx_factory()

ones = actx.thaw(actx.freeze(
actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
))

f = actx.compile(lambda x: 2*x+40)
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)

np.testing.assert_allclose(
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)
f = actx.compile(partial(lambda x: 2*x+40))
np.testing.assert_allclose(actx.to_numpy(f(ones)), 42)
np.testing.assert_allclose(
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)


@pytest.mark.parametrize(
Expand Down
3 changes: 1 addition & 2 deletions test/testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def array_context(self):

@with_container_arithmetic(
bcasts_across_obj_array=False,
container_types_bcast_across=(DOFArray, np.ndarray),
bcast_container_types=(DOFArray, np.ndarray),
matmul=True,
rel_comparison=True,
_cls_has_array_context_attr=True,
Expand Down Expand Up @@ -216,7 +216,6 @@ class Velocity2D:

__array_ufunc__: ClassVar[None] = None


@with_array_context.register(Velocity2D)
# https://github.com/python/mypy/issues/13040
def _with_actx_velocity_2d(ary, actx): # type: ignore[misc]
Expand Down