Skip to content
32 changes: 24 additions & 8 deletions phaser/engines/common/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
from math import prod
import typing as t

import numpy
from numpy.typing import NDArray

Expand All @@ -20,7 +19,14 @@

class ClampObjectAmplitude:
def __init__(self, args: None, props: ClampObjectAmplitudeProps):
self.amplitude = props.amplitude
self.min: t.Optional[float]
self.max: t.Optional[float]

if isinstance(props.amplitude, list):
self.min, self.max = props.amplitude
else:
self.min = None
self.max = props.amplitude

def init_state(self, sim: ReconsState) -> None:
return None
Expand All @@ -29,19 +35,29 @@ def apply_group(self, group: NDArray[numpy.integer], sim: ReconsState, state: No
return self.apply_iter(sim, state)

def apply_iter(self, sim: ReconsState, state: None) -> t.Tuple[ReconsState, None]:
amp = to_real_dtype(sim.object.data.dtype)(self.amplitude)
sim.object.data = clamp_amplitude(sim.object.data, amp)
cast = to_real_dtype(sim.object.data.dtype)
sim.object.data = clamp_amplitude(sim.object.data, self.min, self.max)
return (sim, None)


@partial(jit, donate_argnames=('obj',), cupy_fuse=True)
def clamp_amplitude(obj: NDArray[numpy.complexfloating], amplitude: t.Union[float, numpy.floating]) -> NDArray[numpy.complexfloating]:
def clamp_amplitude(obj: NDArray[numpy.complexfloating], min: t.Optional[float], max: t.Optional[float]) -> NDArray[numpy.complexfloating]:
xp = get_array_module(obj)

obj_amp = xp.abs(obj)
scale = xp.minimum(obj_amp, amplitude) / obj_amp
return obj * scale
new_amp = obj_amp

if min is not None and max is not None:
new_amp = xp.clip(new_amp, min, max)
elif min is not None:
new_amp = xp.maximum(new_amp, min)
elif max is not None:
new_amp = xp.minimum(new_amp, max)
else:
return obj

scale = xp.where(obj_amp > 0, new_amp / obj_amp, 0.0) #no divide by 0
return obj * scale

class LimitProbeSupport:
def __init__(self, args: None, props: LimitProbeSupportProps):
Expand Down Expand Up @@ -548,4 +564,4 @@ def img_grad(img: numpy.ndarray) -> t.Tuple[numpy.ndarray, numpy.ndarray]:
return (
xp.diff(img, axis=-2, append=img[..., -1:, :]),
xp.diff(img, axis=-1, append=img[..., :, -1:]),
)
)
5 changes: 2 additions & 3 deletions phaser/hooks/regularization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import typing as t

import numpy
from numpy.typing import NDArray

Expand Down Expand Up @@ -35,7 +34,7 @@ def calc_loss_group(self, group: NDArray[numpy.integer], sim: 'ReconsState', sta


class ClampObjectAmplitudeProps(Dataclass):
amplitude: float = 1.1
amplitude: t.Union[float, t.List[t.Optional[float]]] = 1.1


class LimitProbeSupportProps(Dataclass):
Expand Down Expand Up @@ -120,4 +119,4 @@ class CostRegularizerHook(Hook[None, CostRegularizer]):
'probe_recip_tv': ('phaser.engines.common.regularizers:ProbeRecipTotalVariation', TVRegularizerProps),
'probe_recip_tikh': ('phaser.engines.common.regularizers:ProbeRecipTikhonov', CostRegularizerProps),
'probe_recip_tikhonov': ('phaser.engines.common.regularizers:ProbeRecipTikhonov', CostRegularizerProps),
}
}