From 135bf346a6034fbd83eaace81df349705ee0f55b Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 26 Jan 2024 13:47:13 -0600 Subject: [PATCH 1/2] add make_distributed_send_ref_holder --- pytato/__init__.py | 4 +++- pytato/distributed/nodes.py | 26 +++++++++++++++++--------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index 572e4a7ab..e593d2867 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -100,6 +100,7 @@ def set_debug_enabled(flag: bool) -> None: from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv, DistributedRecv, DistributedSend, DistributedSendRefHolder, + make_distributed_send_ref_holder, staple_distributed_send) from pytato.distributed.partition import ( find_distributed_partition, DistributedGraphPart, DistributedGraphPartition) @@ -161,7 +162,8 @@ def set_debug_enabled(flag: bool) -> None: "trace_call", "make_distributed_recv", "make_distributed_send", "DistributedRecv", - "DistributedSend", "staple_distributed_send", "DistributedSendRefHolder", + "DistributedSend", "make_distributed_send_ref_holder", + "staple_distributed_send", "DistributedSendRefHolder", "DistributedGraphPart", "DistributedGraphPartition", diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index 465bda312..b18b351a3 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -10,14 +10,11 @@ These functions aid in creating communication nodes: +.. autofunction:: make_distributed_send +.. autofunction:: make_distributed_send_ref_holder .. autofunction:: staple_distributed_send .. autofunction:: make_distributed_recv -For completeness, individual (non-held/"stapled") :class:`DistributedSend` nodes -can be made via this function: - -.. autofunction:: make_distributed_send - Redirections for the documentation tool ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -226,6 +223,16 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp tags=send_tags) +def make_distributed_send_ref_holder( + send: DistributedSend, + passthrough_data: Array, + tags: FrozenSet[Tag] = frozenset() + ) -> DistributedSendRefHolder: + """Make a :class:`DistributedSendRefHolder` object.""" + return DistributedSendRefHolder( + send=send, passthrough_data=passthrough_data, tags=tags) + + def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, stapled_to: Array, *, send_tags: FrozenSet[Tag] = frozenset(), @@ -233,10 +240,11 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT DistributedSendRefHolder: """Make a :class:`DistributedSend` object wrapped in a :class:`DistributedSendRefHolder` object.""" - return DistributedSendRefHolder( - send=DistributedSend(data=sent_data, dest_rank=dest_rank, - comm_tag=comm_tag, tags=send_tags), - passthrough_data=stapled_to, tags=ref_holder_tags) + return make_distributed_send_ref_holder( + send=DistributedSend(data=sent_data, dest_rank=dest_rank, + comm_tag=comm_tag, tags=send_tags), + passthrough_data=stapled_to, + tags=ref_holder_tags) def make_distributed_recv(src_rank: int, comm_tag: CommTagType, From aabb34440d68ddd650739cb23348dabf5eab3e04 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 26 Jan 2024 15:22:20 -0600 Subject: [PATCH 2/2] add default tags to send/recv nodes --- pytato/distributed/nodes.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index b18b351a3..38fec57eb 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -61,7 +61,7 @@ from pytato.array import ( Array, _SuppliedShapeAndDtypeMixin, ShapeType, AxesT, - _get_default_axes, ConvertibleToShape, normalize_shape) + _get_default_axes, _get_default_tags, ConvertibleToShape, normalize_shape) CommTagType = Hashable @@ -220,7 +220,7 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp DistributedSend: """Make a :class:`DistributedSend` object.""" return DistributedSend(data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, - tags=send_tags) + tags=(send_tags | _get_default_tags())) def make_distributed_send_ref_holder( @@ -230,7 +230,8 @@ def make_distributed_send_ref_holder( ) -> DistributedSendRefHolder: """Make a :class:`DistributedSendRefHolder` object.""" return DistributedSendRefHolder( - send=send, passthrough_data=passthrough_data, tags=tags) + send=send, passthrough_data=passthrough_data, + tags=(tags | _get_default_tags())) def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, @@ -241,8 +242,9 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT """Make a :class:`DistributedSend` object wrapped in a :class:`DistributedSendRefHolder` object.""" return make_distributed_send_ref_holder( - send=DistributedSend(data=sent_data, dest_rank=dest_rank, - comm_tag=comm_tag, tags=send_tags), + send=make_distributed_send( + sent_data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, + send_tags=send_tags), passthrough_data=stapled_to, tags=ref_holder_tags) @@ -261,7 +263,7 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType, dtype = np.dtype(dtype) return DistributedRecv( src_rank=src_rank, comm_tag=comm_tag, shape=shape, dtype=dtype, - tags=tags, axes=axes) + axes=axes, tags=(tags | _get_default_tags())) # }}}