From 605c14713266c81db909254f0ba9a55136ff918e Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 5 Apr 2023 13:09:18 -0500 Subject: [PATCH 01/15] NEW: Implement restrict/prolongation operator --- src/tike/ptycho/solvers/options.py | 22 ++++++++++++++++++++++ tests/ptycho/test_multigrid.py | 21 +++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index 0f86c26c..f38e9eb8 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -49,6 +49,7 @@ class IterativeOptions(abc.ABC): """The number of epochs to consider for convergence monitoring. Set to any value less than 2 to disable.""" + @dataclasses.dataclass class AdamOptions(IterativeOptions): name: str = dataclasses.field(default='adam_grad', init=False) @@ -264,3 +265,24 @@ def _resize_fft(x: np.ndarray, f: float) -> np.ndarray: norm='ortho', axes=(-2, -1), ) + + +def _resize_mean(x: np.ndarray, f: float) -> np.ndarray: + """Use an averaging filter to resize/resample the last 2 dimensions of x""" + if f == 1: + return x + if f < 1: + new_shape = ( + *x.shape[:-2], + int(x.shape[-2] * f), + int(1.0 / f), + int(x.shape[-1] * f), + int(1.0 / f), + ) + return np.mean(x.reshape(new_shape), axis=(-1, -3)) + else: + return np.repeat( + np.repeat(x, repeats=f, axis=-2), + repeats=int(f), + axis=-1, + ) diff --git a/tests/ptycho/test_multigrid.py b/tests/ptycho/test_multigrid.py index c55e693e..fce3b1f5 100644 --- a/tests/ptycho/test_multigrid.py +++ b/tests/ptycho/test_multigrid.py @@ -14,6 +14,7 @@ _resize_cubic, _resize_lanczos, _resize_linear, + _resize_mean, ) from .templates import _mpi_size @@ -23,12 +24,32 @@ output_folder = os.path.join(result_dir, 'multigrid') +def test_resize_mean(): + x0 = np.array([[ + [0, 1], + [5, 7], + ]]) + x = np.array([[ + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [5, 5, 5, 7, 7, 7], + [5, 5, 5, 7, 7, 7], + [5, 5, 5, 7, 7, 7], + ]]) + x1 = _resize_mean(x0, 3.0) + np.testing.assert_equal(x1, x) + x2 = _resize_mean(x, 1.0/3.0) + np.testing.assert_equal(x2, x0) + + @pytest.mark.parametrize("function", [ _resize_fft, _resize_spline, _resize_linear, _resize_cubic, _resize_lanczos, + _resize_mean, ]) def test_resample(function, filename='siemens-star-small.npz.bz2'): From 97a1e4b7025e1ed3f9c21f0d4b453957de81afb3 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 6 Apr 2023 11:50:53 -0500 Subject: [PATCH 02/15] REF: Don't have a separate private ptycho parameters --- src/tike/ptycho/ptycho.py | 179 +++++++++++++++++++------------------- 1 file changed, 89 insertions(+), 90 deletions(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index de904e8d..92298cfa 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -317,8 +317,7 @@ def __init__( mpi = tike.communicators.NoMPIComm self.data = data - self.parameters = parameters - self._device_parameters = copy.deepcopy(parameters) + self.parameters = copy.deepcopy(parameters) self.device = cp.cuda.Device( num_gpu[0] if isinstance(num_gpu, tuple) else None) self.operator = tike.operators.Ptycho( @@ -343,9 +342,9 @@ def __enter__(self): odd_pool = self.comm.pool.num_workers % 2 ( self.comm.order, - self._device_parameters.scan, + self.parameters.scan, self.data, - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, ) = tike.cluster.by_scan_grid( self.comm.pool, ( @@ -356,55 +355,55 @@ def __enter__(self): (tike.precision.floating, tike.precision.floating if self.data.itemsize > 2 else self.data.dtype, tike.precision.floating), - self._device_parameters.scan, + self.parameters.scan, self.data, - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, ) - self._device_parameters.psi = self.comm.pool.bcast( - [self._device_parameters.psi.astype(tike.precision.cfloating)]) + self.parameters.psi = self.comm.pool.bcast( + [self.parameters.psi.astype(tike.precision.cfloating)]) - self._device_parameters.probe = self.comm.pool.bcast( - [self._device_parameters.probe.astype(tike.precision.cfloating)]) + self.parameters.probe = self.comm.pool.bcast( + [self.parameters.probe.astype(tike.precision.cfloating)]) - if self._device_parameters.probe_options is not None: - self._device_parameters.probe_options = self._device_parameters.probe_options.copy_to_device( + if self.parameters.probe_options is not None: + self.parameters.probe_options = self.parameters.probe_options.copy_to_device( self.comm,) - if self._device_parameters.object_options is not None: - self._device_parameters.object_options = self._device_parameters.object_options.copy_to_device( + if self.parameters.object_options is not None: + self.parameters.object_options = self.parameters.object_options.copy_to_device( self.comm,) - if self._device_parameters.eigen_probe is not None: - self._device_parameters.eigen_probe = self.comm.pool.bcast([ - self._device_parameters.eigen_probe.astype( + if self.parameters.eigen_probe is not None: + self.parameters.eigen_probe = self.comm.pool.bcast([ + self.parameters.eigen_probe.astype( tike.precision.cfloating) ]) - if self._device_parameters.position_options is not None: + if self.parameters.position_options is not None: # TODO: Consider combining put/split, get/join operations? - self._device_parameters.position_options = self.comm.pool.map( + self.parameters.position_options = self.comm.pool.map( PositionOptions.copy_to_device, - (self._device_parameters.position_options.split(x) + (self.parameters.position_options.split(x) for x in self.comm.order), ) # Unique batch for each device self.batches = self.comm.pool.map( getattr(tike.cluster, - self._device_parameters.algorithm_options.batch_method), - self._device_parameters.scan, - num_cluster=self._device_parameters.algorithm_options.num_batch, + self.parameters.algorithm_options.batch_method), + self.parameters.scan, + num_cluster=self.parameters.algorithm_options.num_batch, ) - self._device_parameters.probe = _rescale_probe( + self.parameters.probe = _rescale_probe( self.operator, self.comm, self.data, - self._device_parameters.psi, - self._device_parameters.scan, - self._device_parameters.probe, - num_batch=self._device_parameters.algorithm_options.num_batch, + self.parameters.psi, + self.parameters.scan, + self.parameters.probe, + num_batch=self.parameters.algorithm_options.num_batch, ) return self @@ -415,108 +414,108 @@ def iterate(self, num_iter: int) -> None: for i in range(num_iter): logger.info( - f"{self._device_parameters.algorithm_options.name} epoch " - f"{len(self._device_parameters.algorithm_options.times):,d}") + f"{self.parameters.algorithm_options.name} epoch " + f"{len(self.parameters.algorithm_options.times):,d}") - if self._device_parameters.probe_options is not None: - if self._device_parameters.probe_options.force_centered_intensity: - self._device_parameters.probe = self.comm.pool.map( + if self.parameters.probe_options is not None: + if self.parameters.probe_options.force_centered_intensity: + self.parameters.probe = self.comm.pool.map( constrain_center_peak, - self._device_parameters.probe, + self.parameters.probe, ) - if self._device_parameters.probe_options.force_sparsity < 1: - self._device_parameters.probe = self.comm.pool.map( + if self.parameters.probe_options.force_sparsity < 1: + self.parameters.probe = self.comm.pool.map( constrain_probe_sparsity, - self._device_parameters.probe, - f=self._device_parameters.probe_options + self.parameters.probe, + f=self.parameters.probe_options .force_sparsity, ) - if self._device_parameters.probe_options.force_orthogonality: + if self.parameters.probe_options.force_orthogonality: ( - self._device_parameters.probe, + self.parameters.probe, power, ) = (list(a) for a in zip(*self.comm.pool.map( tike.ptycho.probe.orthogonalize_eig, - self._device_parameters.probe, + self.parameters.probe, ))) - self._device_parameters.probe_options.power.append(power[0].get()) + self.parameters.probe_options.power.append(power[0].get()) - self._device_parameters = getattr( + self.parameters = getattr( solvers, - self._device_parameters.algorithm_options.name, + self.parameters.algorithm_options.name, )( self.operator, self.comm, data=self.data, batches=self.batches, - parameters=self._device_parameters, + parameters=self.parameters, ) - if self._device_parameters.object_options.clip_magnitude: - self._device_parameters.psi = self.comm.pool.map( + if self.parameters.object_options.clip_magnitude: + self.parameters.psi = self.comm.pool.map( _clip_magnitude, - self._device_parameters.psi, + self.parameters.psi, a_max=1.0, ) - if (self._device_parameters.position_options - and self._device_parameters.position_options[0] + if (self.parameters.position_options + and self.parameters.position_options[0] .use_position_regularization): - (self._device_parameters.position_options + (self.parameters.position_options ) = affine_position_regularization( self.comm, - updated=self._device_parameters.scan, - position_options=self._device_parameters.position_options, + updated=self.parameters.scan, + position_options=self.parameters.position_options, ) - self._device_parameters.algorithm_options.times.append( + self.parameters.algorithm_options.times.append( time.perf_counter() - start) start = time.perf_counter() - if tike.opt.is_converged(self._device_parameters.algorithm_options): + if tike.opt.is_converged(self.parameters.algorithm_options): break def _get_result(self): """Return the current parameter estimates.""" - self.parameters.probe = self._device_parameters.probe[0].get() + self.parameters.probe = self.parameters.probe[0].get() - self.parameters.psi = self._device_parameters.psi[0].get() + self.parameters.psi = self.parameters.psi[0].get() reorder = np.argsort(np.concatenate(self.comm.order)) self.parameters.scan = self.comm.pool.gather_host( - self._device_parameters.scan, + self.parameters.scan, axis=-2, )[reorder] - if self._device_parameters.eigen_probe is not None: - self.parameters.eigen_probe = self._device_parameters.eigen_probe[ + if self.parameters.eigen_probe is not None: + self.parameters.eigen_probe = self.parameters.eigen_probe[ 0].get() - if self._device_parameters.eigen_weights is not None: + if self.parameters.eigen_weights is not None: self.parameters.eigen_weights = self.comm.pool.gather( - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, axis=-3, )[reorder].get() - self.parameters.algorithm_options = self._device_parameters.algorithm_options + self.parameters.algorithm_options = self.parameters.algorithm_options - if self._device_parameters.probe_options is not None: - self.parameters.probe_options = self._device_parameters.probe_options.copy_to_host( + if self.parameters.probe_options is not None: + self.parameters.probe_options = self.parameters.probe_options.copy_to_host( ) - if self._device_parameters.object_options is not None: - self.parameters.object_options = self._device_parameters.object_options.copy_to_host( + if self.parameters.object_options is not None: + self.parameters.object_options = self.parameters.object_options.copy_to_host( ) - if self._device_parameters.position_options is not None: - host_position_options = self._device_parameters.position_options[ + if self.parameters.position_options is not None: + host_position_options = self.parameters.position_options[ 0].empty() for x, o in zip( self.comm.pool.map( PositionOptions.copy_to_host, - self._device_parameters.position_options, + self.parameters.position_options, ), self.comm.order, ): @@ -534,29 +533,29 @@ def get_convergence( ) -> typing.Tuple[typing.List[typing.List[float]], typing.List[float]]: """Return the cost function values and times as a tuple.""" return ( - self._device_parameters.algorithm_options.costs, - self._device_parameters.algorithm_options.times, + self.parameters.algorithm_options.costs, + self.parameters.algorithm_options.times, ) def get_psi(self) -> np.array: """Return the current object estimate as a numpy array.""" - return self._device_parameters.psi[0].get() + return self.parameters.psi[0].get() def get_probe(self) -> typing.Tuple[np.array, np.array, np.array]: """Return the current probe, eigen_probe, weights as numpy arrays.""" reorder = np.argsort(np.concatenate(self.comm.order)) - if self._device_parameters.eigen_probe is None: + if self.parameters.eigen_probe is None: eigen_probe = None else: - eigen_probe = self._device_parameters.eigen_probe[0].get() - if self._device_parameters.eigen_weights is None: + eigen_probe = self.parameters.eigen_probe[0].get() + if self.parameters.eigen_weights is None: eigen_weights = None else: eigen_weights = self.comm.pool.gather( - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, axis=-3, )[reorder].get() - probe = self._device_parameters.probe[0].get() + probe = self.parameters.probe[0].get() return probe, eigen_probe, eigen_weights def peek(self) -> typing.Tuple[np.array, np.array, np.array, np.array]: @@ -592,7 +591,7 @@ def append_new_data( if odd_pool else self.comm.pool.num_workers // 2, 1 if odd_pool else 2, ), - (self._device_parameters.scan[0].dtype, self.data[0].dtype), + (self.parameters.scan[0].dtype, self.data[0].dtype), new_scan, new_data, ) @@ -604,9 +603,9 @@ def append_new_data( new_data, axis=0, ) - self._device_parameters.scan = self.comm.pool.map( + self.parameters.scan = self.comm.pool.map( cp.append, - self._device_parameters.scan, + self.parameters.scan, new_scan, axis=0, ) @@ -619,15 +618,15 @@ def append_new_data( # Rebatch on each device self.batches = self.comm.pool.map( getattr(tike.cluster, - self._device_parameters.algorithm_options.batch_method), - self._device_parameters.scan, - num_cluster=self._device_parameters.algorithm_options.num_batch, + self.parameters.algorithm_options.batch_method), + self.parameters.scan, + num_cluster=self.parameters.algorithm_options.num_batch, ) - if self._device_parameters.eigen_weights is not None: - self._device_parameters.eigen_weights = self.comm.pool.map( + if self.parameters.eigen_weights is not None: + self.parameters.eigen_weights = self.comm.pool.map( cp.pad, - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, pad_width=( (0, len(new_scan)), # position (0, 0), # eigen @@ -636,10 +635,10 @@ def append_new_data( mode='mean', ) - if self._device_parameters.position_options is not None: - self._device_parameters.position_options = self.comm.pool.map( + if self.parameters.position_options is not None: + self.parameters.position_options = self.comm.pool.map( PositionOptions.append, - self._device_parameters.position_options, + self.parameters.position_options, new_scan, ) From 0a56af7b2e5dade64cf529a069903c14deb2f3f4 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 6 Apr 2023 15:41:07 -0500 Subject: [PATCH 03/15] NEW: Get full result from Ptycho context --- src/tike/ptycho/ptycho.py | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 92298cfa..38724b58 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -477,41 +477,38 @@ def iterate(self, num_iter: int) -> None: if tike.opt.is_converged(self.parameters.algorithm_options): break - def _get_result(self): + def get_result(self): """Return the current parameter estimates.""" - self.parameters.probe = self.parameters.probe[0].get() - - self.parameters.psi = self.parameters.psi[0].get() - reorder = np.argsort(np.concatenate(self.comm.order)) - self.parameters.scan = self.comm.pool.gather_host( - self.parameters.scan, - axis=-2, - )[reorder] + parameters = solvers.PtychoParameters( + probe=self.parameters.probe[0].get(), + psi=self.parameters.psi[0].get(), + scan=self.comm.pool.gather_host( + self.parameters.scan, + axis=-2, + )[reorder], + algorithm_options=self.parameters.algorithm_options, + ) if self.parameters.eigen_probe is not None: - self.parameters.eigen_probe = self.parameters.eigen_probe[ - 0].get() + parameters.eigen_probe = self.parameters.eigen_probe[0].get() if self.parameters.eigen_weights is not None: - self.parameters.eigen_weights = self.comm.pool.gather( + parameters.eigen_weights = self.comm.pool.gather( self.parameters.eigen_weights, axis=-3, )[reorder].get() - self.parameters.algorithm_options = self.parameters.algorithm_options - if self.parameters.probe_options is not None: - self.parameters.probe_options = self.parameters.probe_options.copy_to_host( + parameters.probe_options = self.parameters.probe_options.copy_to_host( ) if self.parameters.object_options is not None: - self.parameters.object_options = self.parameters.object_options.copy_to_host( + parameters.object_options = self.parameters.object_options.copy_to_host( ) if self.parameters.position_options is not None: - host_position_options = self.parameters.position_options[ - 0].empty() + host_position_options = self.parameters.position_options[0].empty() for x, o in zip( self.comm.pool.map( PositionOptions.copy_to_host, @@ -520,10 +517,12 @@ def _get_result(self): self.comm.order, ): host_position_options = host_position_options.join(x, o) - self.parameters.position_options = host_position_options + parameters.position_options = host_position_options + + return parameters def __exit__(self, type, value, traceback): - self._get_result() + self.parameters = self.get_result() self.comm.__exit__(type, value, traceback) self.operator.__exit__(type, value, traceback) self.device.__exit__(type, value, traceback) From 39275ae3a464d31972533abff78c730e4136490f Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 18 Apr 2023 11:59:20 -0500 Subject: [PATCH 04/15] REF: Return a copy for copy methods instead of self reference --- src/tike/ptycho/object.py | 23 +++++++++++++++-------- src/tike/ptycho/position.py | 19 +++++++++++-------- src/tike/ptycho/probe.py | 23 +++++++++++++++-------- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 415f6aae..a0a4f77b 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -8,6 +8,7 @@ import dataclasses import logging import typing +import copy import cupy as cp import cupyx.scipy.ndimage @@ -71,23 +72,29 @@ class ObjectOptions: def copy_to_device(self, comm): """Copy to the current GPU memory.""" + options = copy.copy(self) if self.v is not None: - self.v = cp.asarray(self.v) + options.v = cp.asarray(self.v) if self.m is not None: - self.m = cp.asarray(self.m) + options.m = cp.asarray(self.m) if self.preconditioner is not None: - self.preconditioner = comm.pool.bcast([self.preconditioner]) - return self + options.preconditioner = comm.pool.bcast([self.preconditioner]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asarray(self.multigrid_update) + return options def copy_to_host(self): """Copy to the host CPU memory.""" + options = copy.copy(self) if self.v is not None: - self.v = cp.asnumpy(self.v) + options.v = cp.asnumpy(self.v) if self.m is not None: - self.m = cp.asnumpy(self.m) + options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - self.preconditioner = cp.asnumpy(self.preconditioner[0]) - return self + options.preconditioner = cp.asnumpy(self.preconditioner[0]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asnumpy(self.multigrid_update) + return options def resample(self, factor: float) -> ObjectOptions: """Return a new `ObjectOptions` with the parameters rescaled.""" diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 937b8254..c8b7e084 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -119,6 +119,7 @@ import dataclasses import logging import typing +import copy import cupy as cp import cupyx.scipy.ndimage @@ -445,21 +446,23 @@ def join(self, other, indices): def copy_to_device(self): """Copy to the current GPU memory.""" - self.initial_scan = cp.asarray(self.initial_scan) + options = copy.copy(self) + options.initial_scan = cp.asarray(self.initial_scan) if self.confidence is not None: - self.confidence = cp.asarray(self.confidence) + options.confidence = cp.asarray(self.confidence) if self.use_adaptive_moment: - self._momentum = cp.asarray(self._momentum) - return self + options._momentum = cp.asarray(self._momentum) + return options def copy_to_host(self): """Copy to the host CPU memory.""" - self.initial_scan = cp.asnumpy(self.initial_scan) + options = copy.copy(self) + options.initial_scan = cp.asnumpy(self.initial_scan) if self.confidence is not None: - self.confidence = cp.asnumpy(self.confidence) + options.confidence = cp.asnumpy(self.confidence) if self.use_adaptive_moment: - self._momentum = cp.asnumpy(self._momentum) - return self + options._momentum = cp.asnumpy(self._momentum) + return options def resample(self, factor: float) -> PositionOptions: """Return a new `PositionOptions` with the parameters scaled.""" diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index ef2d64c7..9b6034d3 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -36,6 +36,7 @@ """ from __future__ import annotations +import copy import dataclasses import logging import typing @@ -137,23 +138,29 @@ class ProbeOptions: def copy_to_device(self, comm): """Copy to the current GPU memory.""" + options = copy.copy(self) if self.v is not None: - self.v = cp.asarray(self.v) + options.v = cp.asarray(self.v) if self.m is not None: - self.m = cp.asarray(self.m) + options.m = cp.asarray(self.m) if self.preconditioner is not None: - self.preconditioner = comm.pool.bcast([self.preconditioner]) - return self + options.preconditioner = comm.pool.bcast([self.preconditioner]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asarray(self.multigrid_update) + return options def copy_to_host(self): """Copy to the host CPU memory.""" + options = copy.copy(self) if self.v is not None: - self.v = cp.asnumpy(self.v) + options.v = cp.asnumpy(self.v) if self.m is not None: - self.m = cp.asnumpy(self.m) + options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - self.preconditioner = cp.asnumpy(self.preconditioner[0]) - return self + options.preconditioner = cp.asnumpy(self.preconditioner[0]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asnumpy(self.multigrid_update) + return options def resample(self, factor: float) -> ProbeOptions: """Return a new `ProbeOptions` with the parameters rescaled.""" From 1befa1ab99ac358a6235f25f0c49b3de49e6a7ad Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 18 Apr 2023 12:00:35 -0500 Subject: [PATCH 05/15] NEW: Add a multigrid parameter to options classes --- src/tike/ptycho/object.py | 13 +++++++++++-- src/tike/ptycho/probe.py | 13 +++++++++++-- src/tike/ptycho/solvers/options.py | 4 ++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index a0a4f77b..b6a00e1d 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -67,6 +67,12 @@ class ObjectOptions: ) """Used for compact batch updates.""" + multigrid_update: typing.Union[npt.NDArray, None] = dataclasses.field( + init=False, + default_factory=lambda: None, + ) + """Used for multigrid updates.""" + clip_magnitude: bool = False """Whether to force the object magnitude to remain <= 1.""" @@ -96,9 +102,9 @@ def copy_to_host(self): options.multigrid_update = cp.asnumpy(self.multigrid_update) return options - def resample(self, factor: float) -> ObjectOptions: + def resample(self, factor: float, interp) -> ObjectOptions: """Return a new `ObjectOptions` with the parameters rescaled.""" - return ObjectOptions( + options = ObjectOptions( positivity_constraint=self.positivity_constraint, smoothness_constraint=self.smoothness_constraint, use_adaptive_moment=self.use_adaptive_moment, @@ -106,6 +112,9 @@ def resample(self, factor: float) -> ObjectOptions: mdecay=self.mdecay, clip_magnitude=self.clip_magnitude, ) + if self.multigrid_update is not None: + options.multigrid_update = interp(self.multigrid_update, factor) + return options # Momentum reset to zero when grid scale changes diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 9b6034d3..60a507d4 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -136,6 +136,12 @@ class ProbeOptions: ) """The power of the primary probe modes at each iteration.""" + multigrid_update: typing.Union[npt.NDArray, None] = dataclasses.field( + init=False, + default_factory=lambda: None, + ) + """Used for multigrid updates.""" + def copy_to_device(self, comm): """Copy to the current GPU memory.""" options = copy.copy(self) @@ -162,9 +168,9 @@ def copy_to_host(self): options.multigrid_update = cp.asnumpy(self.multigrid_update) return options - def resample(self, factor: float) -> ProbeOptions: + def resample(self, factor: float, interp) -> ProbeOptions: """Return a new `ProbeOptions` with the parameters rescaled.""" - return ProbeOptions( + options = ProbeOptions( force_orthogonality=self.force_orthogonality, force_centered_intensity=self.force_centered_intensity, force_sparsity=self.force_sparsity, @@ -175,6 +181,9 @@ def resample(self, factor: float) -> ProbeOptions: probe_support_degree=self.probe_support_degree, probe_support_radius=self.probe_support_radius, ) + if self.multigrid_update is not None: + options.multigrid_update = interp(self.multigrid_update, factor) + return options # Momentum reset to zero when grid scale changes diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index f38e9eb8..0529ed07 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -179,9 +179,9 @@ def resample( if self.eigen_probe is not None else None, eigen_weights=self.eigen_weights, algorithm_options=self.algorithm_options, - probe_options=self.probe_options.resample(factor) + probe_options=self.probe_options.resample(factor, interp) if self.probe_options is not None else None, - object_options=self.object_options.resample(factor) + object_options=self.object_options.resample(factor, interp) if self.object_options is not None else None, position_options=self.position_options.resample(factor) if self.position_options is not None else None, From 3d9eff91ad7d11762bd75663c599a3dbbd6cde21 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 27 Apr 2023 18:27:47 -0500 Subject: [PATCH 06/15] NEW: Add gradient based convergence criteria --- src/tike/ptycho/object.py | 13 +++++++++++++ src/tike/ptycho/ptycho.py | 14 +++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index b6a00e1d..e149c392 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -26,6 +26,16 @@ class ObjectOptions: """Manage data and setting related to object correction.""" + convergence_tolerance: float = 0 + """Terminate reconstruction early when the mnorm of the object update is + less than this value.""" + + update_mnorm: typing.List[float] = dataclasses.field( + init=False, + default_factory=list, + ) + """A record of the previous mnorms of the object update.""" + positivity_constraint: float = 0 """This value is passed to the tike.ptycho.object.positivity_constraint function.""" @@ -79,6 +89,7 @@ class ObjectOptions: def copy_to_device(self, comm): """Copy to the current GPU memory.""" options = copy.copy(self) + options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: options.v = cp.asarray(self.v) if self.m is not None: @@ -92,6 +103,7 @@ def copy_to_device(self, comm): def copy_to_host(self): """Copy to the host CPU memory.""" options = copy.copy(self) + options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: options.v = cp.asnumpy(self.v) if self.m is not None: @@ -112,6 +124,7 @@ def resample(self, factor: float, interp) -> ObjectOptions: mdecay=self.mdecay, clip_magnitude=self.clip_magnitude, ) + options.update_mnorm = copy.copy(self.update_mnorm) if self.multigrid_update is not None: options.multigrid_update = interp(self.multigrid_update, factor) return options diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 38724b58..747aaa5e 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -411,11 +411,11 @@ def __enter__(self): def iterate(self, num_iter: int) -> None: """Advance the reconstruction by num_iter epochs.""" start = time.perf_counter() + psi_previous = None for i in range(num_iter): - logger.info( - f"{self.parameters.algorithm_options.name} epoch " - f"{len(self.parameters.algorithm_options.times):,d}") + logger.info(f"{self.parameters.algorithm_options.name} epoch " + f"{len(self.parameters.algorithm_options.times):,d}") if self.parameters.probe_options is not None: if self.parameters.probe_options.force_centered_intensity: @@ -476,6 +476,14 @@ def iterate(self, num_iter: int) -> None: if tike.opt.is_converged(self.parameters.algorithm_options): break + if psi_previous is not None: + update_norm = tike.linalg.mnorm(self.parameters.psi[0] - psi_previous) + self.parameters.object_options.update_mnorm.append(update_norm.get()) + logger.info(f"The object update mean-norm is {update_norm:.3e}") + if update_norm < self.parameters.object_options.convergence_tolerance: + logger.info(f"The object seems converged. {update_norm:.3e} < {self.parameters.object_options.convergence_tolerance:.3e}") + break + psi_previous = cp.copy(self.parameters.psi[0]) def get_result(self): """Return the current parameter estimates.""" From 4da5bfc57775abc57ea4d29c625bb73165d7bed0 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Mon, 1 May 2023 19:12:42 -0500 Subject: [PATCH 07/15] BUG: Add missing object option in resample --- src/tike/ptycho/object.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index e149c392..9c1a68d0 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -117,6 +117,7 @@ def copy_to_host(self): def resample(self, factor: float, interp) -> ObjectOptions: """Return a new `ObjectOptions` with the parameters rescaled.""" options = ObjectOptions( + convergence_tolerance=self.convergence_tolerance, positivity_constraint=self.positivity_constraint, smoothness_constraint=self.smoothness_constraint, use_adaptive_moment=self.use_adaptive_moment, From 36f1239c7abb58f5ab6d361997dfff50cd8a3a02 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 May 2023 11:54:44 -0500 Subject: [PATCH 08/15] NEW: increase step size in line_search automatically --- src/tike/opt.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/tike/opt.py b/src/tike/opt.py index 7125e899..1a921dd2 100644 --- a/src/tike/opt.py +++ b/src/tike/opt.py @@ -259,12 +259,18 @@ def line_search( # Decrease the step length while the step increases the cost function step_count = 0 first_step = step_length + step_is_decreasing = False while True: xsd = update_multi(x, step_length, d) fxsd = f(xsd) if fxsd <= fx + step_shrink * m: - break - step_length *= step_shrink + if step_is_decreasing: + break + step_length /= step_shrink + else: + step_length *= step_shrink + step_is_decreasing = True + if step_length < 1e-32: warnings.warn("Line search failed for conjugate gradient.") step_length, fxsd, xsd = 0, fx, x From 1e06e9fa0e262586dab92739b19bf52cef6effe6 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 May 2023 11:55:59 -0500 Subject: [PATCH 09/15] BUG: Using wrong decay parameters for probe ADAM --- src/tike/ptycho/solvers/rpie.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index ba646464..9eaa1420 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -398,8 +398,8 @@ def _update( g=(dprobe)[0, 0, mode, :, :], v=probe_options.v, m=probe_options.m, - vdecay=object_options.vdecay, - mdecay=object_options.mdecay, + vdecay=probe_options.vdecay, + mdecay=probe_options.mdecay, ) probe[0] = probe[0] + dprobe / deno probe = comm.pool.bcast([probe[0]]) From 5b313eddbb1345f10a9fcf4d3584ffbd4d596aa3 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 May 2023 11:57:29 -0500 Subject: [PATCH 10/15] BUG: Use correct mean restriction operator --- src/tike/ptycho/solvers/options.py | 4 ++-- tests/ptycho/test_multigrid.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index 0529ed07..c8589d15 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -279,10 +279,10 @@ def _resize_mean(x: np.ndarray, f: float) -> np.ndarray: int(x.shape[-1] * f), int(1.0 / f), ) - return np.mean(x.reshape(new_shape), axis=(-1, -3)) + return np.sum(x.reshape(new_shape), axis=(-1, -3)) * (f * f) else: return np.repeat( np.repeat(x, repeats=f, axis=-2), repeats=int(f), axis=-1, - ) + ) * (f * f) diff --git a/tests/ptycho/test_multigrid.py b/tests/ptycho/test_multigrid.py index fce3b1f5..617ddbd8 100644 --- a/tests/ptycho/test_multigrid.py +++ b/tests/ptycho/test_multigrid.py @@ -29,6 +29,10 @@ def test_resize_mean(): [0, 1], [5, 7], ]]) + x3 = np.array([[ + [0, 1*9], + [5*9, 7*9], + ]]) x = np.array([[ [0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1], @@ -36,11 +40,11 @@ def test_resize_mean(): [5, 5, 5, 7, 7, 7], [5, 5, 5, 7, 7, 7], [5, 5, 5, 7, 7, 7], - ]]) + ]]) * 9 x1 = _resize_mean(x0, 3.0) np.testing.assert_equal(x1, x) x2 = _resize_mean(x, 1.0/3.0) - np.testing.assert_equal(x2, x0) + np.testing.assert_equal(x2, x3) @pytest.mark.parametrize("function", [ From afe217a3266cf577e5beb1c8d967ccdc4d4f3041 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 May 2023 12:03:54 -0500 Subject: [PATCH 11/15] NEW: Add new multigrid method --- src/tike/ptycho/ptycho.py | 183 ++++++++++++++++++++++++++++++- src/tike/ptycho/solvers/adam.py | 4 + src/tike/ptycho/solvers/dm.py | 4 + src/tike/ptycho/solvers/lstsq.py | 7 ++ src/tike/ptycho/solvers/rpie.py | 4 + 5 files changed, 201 insertions(+), 1 deletion(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 747aaa5e..54a3bcdd 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -54,6 +54,7 @@ "simulate", "Reconstruction", "reconstruct_multigrid", + "reconstruct_multigrid_new", ] import copy @@ -715,7 +716,7 @@ def reconstruct_multigrid( num_gpu: typing.Union[int, typing.Tuple[int, ...]] = 1, use_mpi: bool = False, num_levels: int = 3, - interp=None, + interp: typing.Callable = solvers.options._resize_fft, ) -> solvers.PtychoParameters: """Solve the ptychography problem using a multi-grid method. @@ -764,3 +765,183 @@ def reconstruct_multigrid( resampled_parameters = context.parameters.resample(2.0, interp) raise RuntimeError('This should not happen.') + + +def reconstruct_multigrid_new( + data: npt.NDArray, + parameters: solvers.PtychoParameters, + model: str = 'gaussian', + num_gpu: typing.Union[int, typing.Tuple[int, ...]] = 1, + use_mpi: bool = False, + num_levels: int = 3, + level: int = 0, + interp: typing.Callable = solvers.options._resize_mean, +) -> solvers.PtychoParameters: + """Solve the ptychography problem using a multi-grid method. + + .. versionadded:: 0.23.2 + + Uses the same parameters as the functional reconstruct API. This function + applies a multi-grid approach to the problem by downsampling the real-space + input parameters and cropping the diffraction patterns to reduce the + computational cost of early iterations. + + Parameters + ---------- + num_levels : int > 0 + The number of times to reduce the problem by a factor of two. + + + .. seealso:: :py:func:`tike.ptycho.ptycho.reconstruct` + """ + + if level == 0 and (data.shape[-1] * 0.5**(num_levels - 1) < 64): + warnings.warn('Cropping diffraction patterns to less than 64 pixels ' + 'wide is not recommended because the full doughnut' + ' may not be visible.') + + with tike.ptycho.Reconstruction( + data=data, + parameters=parameters, + model=model, + num_gpu=num_gpu, + use_mpi=use_mpi, + ) as context: + + if context.parameters.object_options.multigrid_update is not None: + grad_psi = context.comm.pool.map( + context.operator.grad_psi, + context.data, + context.parameters.psi, + context.parameters.scan, + context.parameters.probe, + ) + grad_psi = context.comm.Allreduce_reduce_gpu(grad_psi)[0] + context.parameters.object_options.multigrid_update += -grad_psi + + if context.parameters.probe_options.multigrid_update is not None: + grad_probe = context.comm.pool.map( + context.operator.grad_probe, + context.data, + context.parameters.psi, + context.parameters.scan, + context.parameters.probe, + ) + grad_probe = context.comm.Allreduce_reduce_gpu(grad_probe)[0] + context.parameters.probe_options.multigrid_update += -grad_probe + + logging.info(f'Multigrid level {level} pre-smoothing') + + # pre-smoothing + context.iterate(4) + + if level + 1 < num_levels: + + # coarse-grid correction + parameters_coarser = context.get_result().resample(0.5, interp) + data_coarser = solvers.crop_fourier_space( + data, + data.shape[-1] // 2, + ) + + grad_psi = context.comm.pool.map( + context.operator.grad_psi, + context.data, + context.parameters.psi, + context.parameters.scan, + context.parameters.probe, + ) + grad_psi = context.comm.Allreduce_reduce_cpu(grad_psi) + if context.parameters.object_options.multigrid_update is None: + parameters_coarser.object_options.multigrid_update = interp(grad_psi, 0.5) + else: + parameters_coarser.object_options.multigrid_update += interp(grad_psi, 0.5) + + grad_probe = context.comm.pool.map( + context.operator.grad_probe, + context.data, + context.parameters.psi, + context.parameters.scan, + context.parameters.probe, + ) + grad_probe = context.comm.Allreduce_reduce_cpu(grad_probe) + if context.parameters.probe_options.multigrid_update is None: + parameters_coarser.probe_options.multigrid_update = interp(grad_probe, 0.5) + else: + parameters_coarser.probe_options.multigrid_update += interp(grad_probe, 0.5) + + parameters_coarser_updated = reconstruct_multigrid_new( + data=data_coarser, + parameters=parameters_coarser, + num_gpu=num_gpu, + model=model, + use_mpi=use_mpi, + num_levels=num_levels, + level=level + 1, + interp=interp, + ) + + context.parameters.algorithm_options.times = parameters_coarser_updated.algorithm_options.times + context.parameters.algorithm_options.costs = parameters_coarser_updated.algorithm_options.costs + + def update_multi(x, gamma, dir): + + def f(x, dir): + return x + gamma * dir + + return list(context.comm.pool.map(f, x, dir)) + + def cost_function_psi(psi, **kwargs): + cost_out = context.comm.pool.map( + context.operator.cost, + context.data, + psi, + context.parameters.scan, + context.parameters.probe, + ) + return context.comm.Allreduce_mean(cost_out, axis=None).get() + + def cost_function_probe(probe, **kwargs): + cost_out = context.comm.pool.map( + context.operator.cost, + context.data, + context.parameters.psi, + context.parameters.scan, + probe, + ) + return context.comm.Allreduce_mean(cost_out, axis=None).get() + + logging.info(f'Multigrid level {level} upsample update') + + _, _, context.parameters.psi = tike.opt.line_search( + f=cost_function_psi, + x=context.parameters.psi, + d=context.comm.pool.bcast([ + interp( + parameters_coarser_updated.psi - parameters_coarser.psi, + 2.0, + ) + ]), + update_multi=update_multi, + ) + + _, _, context.parameters.probe = tike.opt.line_search( + f=cost_function_probe, + x=context.parameters.probe, + d=context.comm.pool.bcast([ + interp( + parameters_coarser_updated.probe - + parameters_coarser.probe, + 2.0, + ) + ]), + update_multi=update_multi, + ) + + logging.info(f'Multigrid level {level} post-smoothing') + + # post-smoothing + context.iterate(parameters.algorithm_options.num_iter) + + print(f"Return level {level}") + return context.parameters diff --git a/src/tike/ptycho/solvers/adam.py b/src/tike/ptycho/solvers/adam.py index feab4135..9fa6b37f 100644 --- a/src/tike/ptycho/solvers/adam.py +++ b/src/tike/ptycho/solvers/adam.py @@ -224,6 +224,8 @@ def _update_all( mdecay=object_options.mdecay, ) psi[0] = psi[0] - algorithm_options.step_length * dpsi / deno + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update / deno psi = comm.pool.bcast([psi[0]]) if probe_options: @@ -247,6 +249,8 @@ def _update_all( mdecay=object_options.mdecay, ) probe[0] = probe[0] - algorithm_options.step_length * dprobe + if probe_options.multigrid_update is not None: + probe[0] = probe[0] + probe_options.multigrid_update / deno probe = comm.pool.bcast([probe[0]]) return psi, probe diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index b8932088..a962ab60 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -279,6 +279,8 @@ def _apply_update( mdecay=object_options.mdecay, ) new_psi = dpsi + psi[0] + if object_options.multigrid_update is not None: + new_psi = new_psi + object_options.multigrid_update psi = comm.pool.bcast([new_psi]) if recover_probe: @@ -301,6 +303,8 @@ def _apply_update( mdecay=probe_options.mdecay, ) new_probe = dprobe + probe[0] + if probe_options.multigrid_update is not None: + new_probe = new_probe + probe_options.multigrid_update probe = comm.pool.bcast([new_probe]) return psi, probe diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index eebbfd9b..ff121b38 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -236,6 +236,9 @@ def lstsq_grad( dpsi = beta_object * object_update_precond psi[0] = psi[0] + dpsi + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update + if object_options.use_adaptive_moment: ( dpsi, @@ -557,6 +560,10 @@ def _update_nearplane( mdecay=object_options.mdecay, ) psi[0] = psi[0] + dpsi + + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update + psi = comm.pool.bcast([psi[0]]) else: object_options.combined_update += object_upd_sum[0] diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index 9eaa1420..e42b0ad9 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -354,6 +354,8 @@ def _update( mdecay=object_options.mdecay, ) psi[0] = psi[0] + dpsi / deno + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update / deno psi = comm.pool.bcast([psi[0]]) if probe_options: @@ -402,6 +404,8 @@ def _update( mdecay=probe_options.mdecay, ) probe[0] = probe[0] + dprobe / deno + if probe_options.multigrid_update is not None: + probe[0] = probe[0] + probe_options.multigrid_update / deno probe = comm.pool.bcast([probe[0]]) return psi, probe From 66e9fa4043c9c710137939deb19821b5521626b5 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 May 2023 15:24:32 -0500 Subject: [PATCH 12/15] REF: Smooth the convergence criteria --- src/tike/ptycho/ptycho.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 54a3bcdd..8f414b70 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -412,7 +412,7 @@ def __enter__(self): def iterate(self, num_iter: int) -> None: """Advance the reconstruction by num_iter epochs.""" start = time.perf_counter() - psi_previous = None + psi_previous = self.parameters.psi[0].copy() for i in range(num_iter): logger.info(f"{self.parameters.algorithm_options.name} epoch " @@ -460,9 +460,8 @@ def iterate(self, num_iter: int) -> None: a_max=1.0, ) - if (self.parameters.position_options - and self.parameters.position_options[0] - .use_position_regularization): + if (self.parameters.position_options and self.parameters + .position_options[0].use_position_regularization): (self.parameters.position_options ) = affine_position_regularization( @@ -471,20 +470,22 @@ def iterate(self, num_iter: int) -> None: position_options=self.parameters.position_options, ) - self.parameters.algorithm_options.times.append( - time.perf_counter() - start) + self.parameters.algorithm_options.times.append(time.perf_counter() - + start) start = time.perf_counter() - if tike.opt.is_converged(self.parameters.algorithm_options): + update_norm = tike.linalg.mnorm(self.parameters.psi[0] - + psi_previous) + self.parameters.object_options.update_mnorm.append( + update_norm.get()) + logger.info(f"The object update mean-norm is {update_norm:.3e}") + if (np.mean(self.parameters.object_options.update_mnorm[-5:]) < + self.parameters.object_options.convergence_tolerance): + logger.info( + f"The object seems converged. {update_norm:.3e} < " + f"{self.parameters.object_options.convergence_tolerance:.3e}" + ) break - if psi_previous is not None: - update_norm = tike.linalg.mnorm(self.parameters.psi[0] - psi_previous) - self.parameters.object_options.update_mnorm.append(update_norm.get()) - logger.info(f"The object update mean-norm is {update_norm:.3e}") - if update_norm < self.parameters.object_options.convergence_tolerance: - logger.info(f"The object seems converged. {update_norm:.3e} < {self.parameters.object_options.convergence_tolerance:.3e}") - break - psi_previous = cp.copy(self.parameters.psi[0]) def get_result(self): """Return the current parameter estimates.""" From 68e6bea929c5ed990d3f10706abc77164de71a02 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 3 May 2023 15:13:52 -0500 Subject: [PATCH 13/15] BUG: Use correct options class --- src/tike/ptycho/solvers/adam.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tike/ptycho/solvers/adam.py b/src/tike/ptycho/solvers/adam.py index 9fa6b37f..c814ca12 100644 --- a/src/tike/ptycho/solvers/adam.py +++ b/src/tike/ptycho/solvers/adam.py @@ -245,8 +245,8 @@ def _update_all( g=dprobe, v=probe_options.v, m=probe_options.m, - vdecay=object_options.vdecay, - mdecay=object_options.mdecay, + vdecay=probe_options.vdecay, + mdecay=probe_options.mdecay, ) probe[0] = probe[0] - algorithm_options.step_length * dprobe if probe_options.multigrid_update is not None: From c12ea190e7acb68c9c758bbfeb8e3ea3d8266ef8 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 3 May 2023 15:14:44 -0500 Subject: [PATCH 14/15] NEW: Use minibatches for new multi-grid method --- src/tike/ptycho/ptycho.py | 96 +++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 8f414b70..9a668812 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -376,10 +376,8 @@ def __enter__(self): self.comm,) if self.parameters.eigen_probe is not None: - self.parameters.eigen_probe = self.comm.pool.bcast([ - self.parameters.eigen_probe.astype( - tike.precision.cfloating) - ]) + self.parameters.eigen_probe = self.comm.pool.bcast( + [self.parameters.eigen_probe.astype(tike.precision.cfloating)]) if self.parameters.position_options is not None: # TODO: Consider combining put/split, get/join operations? @@ -428,8 +426,7 @@ def iterate(self, num_iter: int) -> None: self.parameters.probe = self.comm.pool.map( constrain_probe_sparsity, self.parameters.probe, - f=self.parameters.probe_options - .force_sparsity, + f=self.parameters.probe_options.force_sparsity, ) if self.parameters.probe_options.force_orthogonality: @@ -810,25 +807,29 @@ def reconstruct_multigrid_new( ) as context: if context.parameters.object_options.multigrid_update is not None: - grad_psi = context.comm.pool.map( - context.operator.grad_psi, - context.data, - context.parameters.psi, - context.parameters.scan, - context.parameters.probe, - ) - grad_psi = context.comm.Allreduce_reduce_gpu(grad_psi)[0] + grad_psi = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_psi = context.comm.pool.map( + context.operator.grad_psi, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_psi += context.comm.Allreduce_reduce_gpu(_grad_psi)[0] context.parameters.object_options.multigrid_update += -grad_psi if context.parameters.probe_options.multigrid_update is not None: - grad_probe = context.comm.pool.map( - context.operator.grad_probe, - context.data, - context.parameters.psi, - context.parameters.scan, - context.parameters.probe, - ) - grad_probe = context.comm.Allreduce_reduce_gpu(grad_probe)[0] + grad_probe = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_probe = context.comm.pool.map( + context.operator.grad_probe, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_probe += context.comm.Allreduce_reduce_gpu(_grad_probe)[0] context.parameters.probe_options.multigrid_update += -grad_probe logging.info(f'Multigrid level {level} pre-smoothing') @@ -845,31 +846,39 @@ def reconstruct_multigrid_new( data.shape[-1] // 2, ) - grad_psi = context.comm.pool.map( - context.operator.grad_psi, - context.data, - context.parameters.psi, - context.parameters.scan, - context.parameters.probe, - ) - grad_psi = context.comm.Allreduce_reduce_cpu(grad_psi) + grad_psi = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_psi = context.comm.pool.map( + context.operator.grad_psi, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_psi += context.comm.Allreduce_reduce_cpu(_grad_psi) if context.parameters.object_options.multigrid_update is None: - parameters_coarser.object_options.multigrid_update = interp(grad_psi, 0.5) + parameters_coarser.object_options.multigrid_update = interp( + grad_psi, 0.5) else: - parameters_coarser.object_options.multigrid_update += interp(grad_psi, 0.5) - - grad_probe = context.comm.pool.map( - context.operator.grad_probe, - context.data, - context.parameters.psi, - context.parameters.scan, - context.parameters.probe, - ) - grad_probe = context.comm.Allreduce_reduce_cpu(grad_probe) + parameters_coarser.object_options.multigrid_update += interp( + grad_psi, 0.5) + + grad_probe = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_probe = context.comm.pool.map( + context.operator.grad_probe, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_probe += context.comm.Allreduce_reduce_cpu(_grad_probe) if context.parameters.probe_options.multigrid_update is None: - parameters_coarser.probe_options.multigrid_update = interp(grad_probe, 0.5) + parameters_coarser.probe_options.multigrid_update = interp( + grad_probe, 0.5) else: - parameters_coarser.probe_options.multigrid_update += interp(grad_probe, 0.5) + parameters_coarser.probe_options.multigrid_update += interp( + grad_probe, 0.5) parameters_coarser_updated = reconstruct_multigrid_new( data=data_coarser, @@ -884,6 +893,7 @@ def reconstruct_multigrid_new( context.parameters.algorithm_options.times = parameters_coarser_updated.algorithm_options.times context.parameters.algorithm_options.costs = parameters_coarser_updated.algorithm_options.costs + context.parameters.object_options.update_mnorm = parameters_coarser_updated.object_options.update_mnorm def update_multi(x, gamma, dir): From 4bb9ee6781d429bbf3d5e21430b85dbfafcea0bf Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 3 May 2023 17:00:57 -0500 Subject: [PATCH 15/15] TST: Add test for new multigrid --- tests/ptycho/test_multigrid.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/ptycho/test_multigrid.py b/tests/ptycho/test_multigrid.py index 617ddbd8..9b3f25c9 100644 --- a/tests/ptycho/test_multigrid.py +++ b/tests/ptycho/test_multigrid.py @@ -111,6 +111,49 @@ def template_consistent_algorithm(self, *, data, params): return parameters +@unittest.skipIf( + _mpi_size > 1, + reason="MPI not implemented for multi-grid.", +) +class ReconMultiGridNew(): + """Test ptychography multi-grid reconstruction method.""" + + def interp(self, x, f): + pass + + def template_consistent_algorithm(self, *, data, params): + """Check ptycho.solver.algorithm for consistency.""" + if _mpi_size > 1: + raise NotImplementedError() + + with cp.cuda.Device(self.gpu_indices[0]): + parameters = tike.ptycho.reconstruct_multigrid_new( + parameters=params, + data=self.data, + num_gpu=self.gpu_indices, + use_mpi=self.mpi_size > 1, + num_levels=2, + interp=self.interp, + ) + + print() + print('\n'.join( + f'{c[0]:1.3e}' for c in parameters.algorithm_options.costs)) + return parameters + + +class TestPtychoReconMultiGridMean( + ReconMultiGridNew, + PtychoRecon, + unittest.TestCase, +): + + post_name = '-multigrid-mean' + + def interp(self, x, f): + return _resize_mean(x, f) + + class TestPtychoReconMultiGridFFT( ReconMultiGrid, PtychoRecon,