From 925ce61faec82e442935bb3072ca5e6385103c8e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Feb 2024 17:15:47 -0600 Subject: [PATCH] add MPICupyArrayContext --- grudge/array_context.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/grudge/array_context.py b/grudge/array_context.py index c5672178d..b1113ce23 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -3,6 +3,7 @@ .. autoclass:: PytatoPyOpenCLArrayContext .. autoclass:: MPIBasedArrayContext .. autoclass:: MPIPyOpenCLArrayContext +.. autoclass:: MPICupyArrayContext .. class:: MPIPytatoArrayContext .. autofunction:: get_reasonable_array_context_class """ @@ -97,6 +98,8 @@ from arraycontext.container import ArrayContainer from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller +from arraycontext import CupyArrayContext + if TYPE_CHECKING: import pytato as pt from pytato.partition import PartId @@ -421,6 +424,26 @@ def clone(self): # }}} +# {{{ + +class MPICupyArrayContext(CupyArrayContext, MPIBasedArrayContext): + """An array context for using distributed computation with :mod:`cupy` + eager evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator): + super().__init__() + + self.mpi_communicator = mpi_communicator + + def clone(self): + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pyopencl class MPIPyOpenCLArrayContext(PyOpenCLArrayContext, MPIBasedArrayContext):