Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev,jax,web]'
pip install -e '.[dev,jax,torch,web]'
- name: Test
run: pytest
14 changes: 11 additions & 3 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import pytest
import numpy

BACKENDS: set[str] = set(('cpu', 'jax', 'cuda'))
BACKENDS: set[str] = set(('cpu', 'jax', 'cupy', 'torch'))
AVAILABLE_BACKENDS: set[str] = set(('cpu',))

try:
import cupy # pyright: ignore[reportMissingImports]
if cupy.cuda.runtime.getDeviceCount() > 0:
AVAILABLE_BACKENDS.add('cuda')
except ImportError:
AVAILABLE_BACKENDS.add('cupy')
except (ImportError, RuntimeError):
pass

try:
Expand All @@ -19,6 +19,14 @@
pass


try:
import torch # pyright: ignore[reportMissingImports] # noqa: F401
torch.asarray([1, 2, 3, 4]).numpy(force=True) # ensures torch is loaded, fixes a strange error with pytest
AVAILABLE_BACKENDS.add('torch')
except ImportError:
pass


def pytest_addoption(parser: pytest.Parser):
parser.addoption("--save-expected", action="store_true", dest='save-expected', default=False,
help="Overwrite expected files with the results of tests.")
Expand Down
7 changes: 4 additions & 3 deletions examples/mos2_epie.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
---
name: "mos2_epie"
backend: cupy
backend: torch

# raw data source
raw_data:
type: empad
path: "sample_data/simulated_mos2/mos2_0.00_dstep1.0.json"
path: "~/Downloads/mos2/1/mos2/mos2_0.00_dstep1.0.json"

post_load:
- type: poisson
Expand All @@ -32,7 +32,8 @@ engines:
beta_probe: 0.5

group_constraints: []
iter_constraints: []
iter_constraints:
- type: remove_phase_ramp

update_probe: {after: 5}

Expand Down
2 changes: 1 addition & 1 deletion examples/optuna_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def save_json(self, plan: ReconsPlan, engines: t.Iterable[EngineHook]):
with open(self.trial_path / 'plan.json', 'w') as f:
json.dump(plan_json, f, indent=4)

def update_iteration(self, state: ReconsState, i: int, n: int, error: t.Optional[float] = None):
def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]):
i = state.iter.total_iter

if i % MEASURE_EVERY == 0:
Expand Down
19 changes: 19 additions & 0 deletions notebooks/conventions.ipynb

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion phaser/engines/common/noise_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from phaser.hooks.solver import NoiseModel
from phaser.plan import AmplitudeNoisePlan, AnscombeNoisePlan, PoissonNoisePlan
from phaser.utils.num import get_array_module, Float
from phaser.utils.num import get_array_module, Float, to_numpy
from phaser.state import ReconsState


Expand Down
68 changes: 57 additions & 11 deletions phaser/engines/common/regularizers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
import logging
from math import prod
import typing as t

import numpy
Expand Down Expand Up @@ -196,6 +197,10 @@ class ObjL1:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'obj_l1'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -205,14 +210,18 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(xp.abs(sim.object.data - 1.0))
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state)


class ObjL2:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: Float = props.cost

@staticmethod
def name() -> str:
return 'obj_l2'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -222,14 +231,19 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(abs2(sim.object.data - 1.0))
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
return (cost * cost_scale * self.cost, state)

cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state) # type: ignore


class ObjPhaseL1:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'obj_phase_l1'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -239,14 +253,18 @@ def calc_loss_group(
xp = get_array_module(sim.object.data)

cost = xp.sum(xp.abs(xp.angle(sim.object.data)))
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
return (cost * cost_scale * self.cost, state)


class ObjRecipL1:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'obj_recip_l1'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -261,7 +279,7 @@ def calc_loss_group(
xp.abs(fft2(xp.prod(sim.object.data, axis=0)))
)
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand All @@ -271,6 +289,10 @@ def __init__(self, args: None, props: TVRegularizerProps):
self.cost: float = props.cost
self.eps: float = props.eps

@staticmethod
def name() -> str:
return 'obj_tv'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -289,7 +311,7 @@ def calc_loss_group(
#)
# scale cost by fraction of the total reconstruction in the group
# TODO also scale by # of pixels or similar?
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand All @@ -298,6 +320,10 @@ class ObjTikhonov:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'obj_tikh'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -311,15 +337,19 @@ def calc_loss_group(
xp.sum(abs2(xp.diff(sim.object.data, axis=-2)))
)
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)
return (cost * cost_scale * self.cost, state) # type: ignore


class LayersTotalVariation:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'layers_tv'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -333,7 +363,7 @@ def calc_loss_group(

cost = xp.sum(xp.abs(xp.diff(sim.object.data, axis=0)))
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)

Expand All @@ -342,6 +372,10 @@ class LayersTikhonov:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'layers_tikh'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -355,15 +389,19 @@ def calc_loss_group(

cost = xp.sum(abs2(xp.diff(sim.object.data, axis=0)))
# scale cost by fraction of the total reconstruction in the group
cost_scale = (group.shape[-1] / numpy.prod(sim.scan.shape[:-1])).astype(cost.dtype)
cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)

return (cost * cost_scale * self.cost, state)
return (cost * cost_scale * self.cost, state) # type: ignore


class ProbePhaseTikhonov:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'probe_phase_tikh'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -387,6 +425,10 @@ class ProbeRecipTikhonov:
def __init__(self, args: None, props: CostRegularizerProps):
self.cost: float = props.cost

@staticmethod
def name() -> str:
return 'probe_recip_tikh'

def init_state(self, sim: ReconsState) -> None:
return None

Expand All @@ -410,6 +452,10 @@ def __init__(self, args: None, props: TVRegularizerProps):
self.cost: float = props.cost
self.eps: float = props.eps

@staticmethod
def name() -> str:
return 'probe_recip_tv'

def init_state(self, sim: ReconsState) -> None:
return None

Expand Down
14 changes: 8 additions & 6 deletions phaser/engines/common/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import typing as t

import numpy
from numpy.typing import NDArray, DTypeLike
from numpy.typing import NDArray
from typing_extensions import Self

from phaser.utils.num import (
get_array_module, to_real_dtype, to_complex_dtype,
fft2, ifft2, is_jax, to_numpy, block_until_ready, ufunc_outer
)
from phaser.utils.misc import FloatKey, jax_dataclass, create_compact_groupings, create_sparse_groupings, shuffled
from phaser.utils.tree import tree_dataclass
from phaser.utils.misc import FloatKey, create_compact_groupings, create_sparse_groupings, shuffled
from phaser.utils.optics import fresnel_propagator, fourier_shift_filter
from phaser.state import ReconsState
from phaser.hooks.solver import NoiseModel
Expand All @@ -31,7 +32,8 @@ def __init__(
self.compact = compact
self.seed = seed
self.groups: t.Optional[t.List[NDArray[numpy.int64]]] = None
self.n_groups: int = int(numpy.ceil(numpy.prod(scan.shape[:-1]) / self.grouping))
self.n_pos = numpy.prod(scan.shape[:-1])
self.n_groups: int = int(numpy.ceil(self.n_pos / self.grouping))

def _make(self, scan: NDArray[numpy.floating], i: int = 0) -> t.List[NDArray[numpy.int64]]:
if self.compact:
Expand Down Expand Up @@ -83,7 +85,7 @@ def stream_patterns(
continue


@jax_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx'))
@tree_dataclass(init=False, static_fields=('xp', 'dtype', 'noise_model', 'group_constraints', 'iter_constraints'), drop_fields=('ky', 'kx'))
class SimulationState:
state: ReconsState

Expand All @@ -99,7 +101,7 @@ class SimulationState:
iter_constraint_states: t.Tuple[t.Any, ...]

xp: t.Any
dtype: DTypeLike
dtype: t.Type[numpy.floating]
start_iter: int

def __init__(
Expand All @@ -109,7 +111,7 @@ def __init__(
group_constraints: t.Tuple[GroupConstraint[t.Any], ...],
iter_constraints: t.Tuple[IterConstraint[t.Any], ...],
xp: t.Any,
dtype: DTypeLike,
dtype: t.Type[numpy.floating],
noise_model_state: t.Optional[t.Any] = None,
group_constraint_states: t.Optional[t.Tuple[t.Any, ...]] = None,
iter_constraint_states: t.Optional[t.Tuple[t.Any, ...]] = None,
Expand Down
Loading