diff --git a/examples/wave/wave-min-mpi.py b/examples/wave/wave-min-mpi.py
index 6c56353bd..17bb9544d 100644
--- a/examples/wave/wave-min-mpi.py
+++ b/examples/wave/wave-min-mpi.py
@@ -49,7 +49,7 @@ class WaveTag:
 
 def main(ctx_factory, dim=2, order=4, visualize=False):
     comm = MPI.COMM_WORLD
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
@@ -60,10 +60,10 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
             force_device_scalars=True,
         )
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
+    from meshmode.mesh.processing import partition_mesh
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(
                 a=(-0.5,)*dim,
@@ -72,14 +72,16 @@ def main(ctx_factory, dim=2, order=4, visualize=False):
 
         logger.info("%d elements", mesh.nelements)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
 
         del mesh
 
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=order)
 
diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py
index 8c23336d0..57e6a76de 100644
--- a/examples/wave/wave-op-mpi.py
+++ b/examples/wave/wave-op-mpi.py
@@ -184,7 +184,7 @@ def main(ctx_factory, dim=2, order=3,
     queue = cl.CommandQueue(cl_ctx)
 
     comm = MPI.COMM_WORLD
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
     from grudge.array_context import get_reasonable_array_context_class
     actx_class = get_reasonable_array_context_class(lazy=lazy, distributed=True)
@@ -195,12 +195,12 @@ def main(ctx_factory, dim=2, order=3,
             allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)),
             force_device_scalars=True)
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import get_partition_by_pymetis, membership_list_to_map
+    from meshmode.mesh.processing import partition_mesh
 
     nel_1d = 16
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         if use_nonaffine_mesh:
             from meshmode.mesh.generation import generate_warped_rect_mesh
             # FIXME: *generate_warped_rect_mesh* in meshmode warps a
@@ -218,14 +218,17 @@ def main(ctx_factory, dim=2, order=3,
 
         logger.info("%d elements", mesh.nelements)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
 
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        local_mesh = comm.scatter(parts)
 
         del mesh
 
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     from meshmode.discretization.poly_element import \
             QuadratureSimplexGroupFactory, \
diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py
index d95089b5a..4be934588 100644
--- a/test/test_mpi_communication.py
+++ b/test/test_mpi_communication.py
@@ -115,24 +115,26 @@ def _test_func_comparison_mpi_communication_entrypoint(actx):
     comm = actx.mpi_communicator
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
+    from meshmode.distributed import (
+        get_partition_by_pymetis, membership_list_to_map)
     from meshmode.mesh import BTAG_ALL
+    from meshmode.mesh.processing import partition_mesh
 
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
-    mesh_dist = MPIMeshDistributor(comm)
-
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(a=(-1,)*2,
                                           b=(1,)*2,
                                           nelements_per_axis=(2,)*2)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=5)
 
@@ -188,28 +190,30 @@ def test_mpi_wave_op(actx_class, num_ranks):
 
 def _test_mpi_wave_op_entrypoint(actx, visualize=False):
     comm = actx.mpi_communicator
-    i_local_rank = comm.Get_rank()
-    num_parts = comm.Get_size()
+    num_parts = comm.size
 
-    from meshmode.distributed import MPIMeshDistributor, get_partition_by_pymetis
-    mesh_dist = MPIMeshDistributor(comm)
+    from meshmode.distributed import (
+        get_partition_by_pymetis, membership_list_to_map)
+    from meshmode.mesh.processing import partition_mesh
 
     dim = 2
     order = 4
 
-    if mesh_dist.is_mananger_rank():
+    if comm.rank == 0:
         from meshmode.mesh.generation import generate_regular_rect_mesh
         mesh = generate_regular_rect_mesh(a=(-0.5,)*dim,
                                           b=(0.5,)*dim,
                                           nelements_per_axis=(16,)*dim)
 
-        part_per_element = get_partition_by_pymetis(mesh, num_parts)
-
-        local_mesh = mesh_dist.send_mesh_parts(mesh, part_per_element, num_parts)
+        part_id_to_part = partition_mesh(mesh,
+                membership_list_to_map(
+                    get_partition_by_pymetis(mesh, num_parts)))
+        parts = [part_id_to_part[i] for i in range(num_parts)]
+        local_mesh = comm.scatter(parts)
 
         del mesh
 
     else:
-        local_mesh = mesh_dist.receive_mesh_part()
+        local_mesh = comm.scatter(None)
 
     dcoll = DiscretizationCollection(actx, local_mesh, order=order)
 
@@ -270,7 +274,7 @@ def rhs(t, w):
     final_t = 4
     nsteps = int(final_t/dt)
 
-    logger.info("[%04d] dt %.5e nsteps %4d", i_local_rank, dt, nsteps)
+    logger.info("[%04d] dt %.5e nsteps %4d", comm.rank, dt, nsteps)
 
     step = 0
 
@@ -308,7 +312,7 @@ def rhs(t, w):
             logmgr.tick_after()
 
     logmgr.close()
 
-    logger.info("Rank %d exiting", i_local_rank)
+    logger.info("Rank %d exiting", comm.rank)
 
 # }}}
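All three files migrate the same pattern: the removed MPIMeshDistributor helper (with send_mesh_parts/receive_mesh_part) is replaced by partitioning on rank 0 via partition_mesh plus membership_list_to_map, followed by a plain mpi4py scatter. The sketch below is a minimal, self-contained version of that flow for reference; the helper name scatter_mesh_parts and the mesh parameters are illustrative additions (not part of the patch), and it assumes mpi4py plus a meshmode version that provides membership_list_to_map and meshmode.mesh.processing.partition_mesh.

    # Minimal sketch of the rank-0 partition-and-scatter pattern used above.
    from mpi4py import MPI

    from meshmode.distributed import (
        get_partition_by_pymetis, membership_list_to_map)
    from meshmode.mesh.processing import partition_mesh


    def scatter_mesh_parts(comm):
        """Partition a mesh on rank 0 and scatter one part to each rank.

        Hypothetical helper, named here only for illustration.
        """
        num_parts = comm.size

        if comm.rank == 0:
            from meshmode.mesh.generation import generate_regular_rect_mesh
            mesh = generate_regular_rect_mesh(
                    a=(-0.5,)*2, b=(0.5,)*2, nelements_per_axis=(16,)*2)

            # pymetis yields a flat element-to-part membership list; turn it
            # into a mapping of part id -> element indices and split the mesh.
            part_id_to_part = partition_mesh(
                    mesh,
                    membership_list_to_map(
                        get_partition_by_pymetis(mesh, num_parts)))

            # Order the parts by rank so that scatter sends part i to rank i.
            parts = [part_id_to_part[i] for i in range(num_parts)]
        else:
            parts = None

        # A plain mpi4py scatter replaces the removed
        # MPIMeshDistributor.send_mesh_parts/receive_mesh_part pair.
        return comm.scatter(parts)


    if __name__ == "__main__":
        local_mesh = scatter_mesh_parts(MPI.COMM_WORLD)
        print(f"rank {MPI.COMM_WORLD.rank}: {local_mesh.nelements} elements")

Run with, e.g., mpiexec -n 4 python sketch.py. Note that comm.scatter pickles each mesh part, which is fine for the modest meshes in these examples and tests; the sendobj argument is ignored on non-root ranks, so passing None there matches what the diffs do.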