diff --git a/gui/app.py b/gui/app.py index aac8551..2f6e7ec 100644 --- a/gui/app.py +++ b/gui/app.py @@ -328,6 +328,9 @@ def __init__( fusion_mode="blended", flatfield=None, darkfield=None, + zarr_version=2, + separate_timepoints=False, + blend_bias=0.5, ): super().__init__() self.tiff_path = tiff_path @@ -337,6 +340,9 @@ def __init__( self.fusion_mode = fusion_mode self.flatfield = flatfield self.darkfield = darkfield + self.zarr_version = zarr_version + self.separate_timepoints = separate_timepoints + self.blend_bias = blend_bias self.output_path = None def run(self): @@ -379,6 +385,9 @@ def run(self): downsample_factors=(self.downsample_factor, self.downsample_factor), flatfield=self.flatfield, darkfield=self.darkfield, + zarr_version=self.zarr_version, + separate_timepoints=self.separate_timepoints, + blend_bias=self.blend_bias, ) load_time = time.time() - step_start self.progress.emit(f"Loaded {tf.n_tiles} tiles ({tf.Y}x{tf.X} each) [{load_time:.1f}s]") @@ -702,7 +711,7 @@ class StitcherGUI(QMainWindow): def __init__(self): super().__init__() self.setWindowTitle("Stitcher") - self.setMinimumSize(500, 850) + self.setMinimumSize(500, 950) self.worker = None self.output_path = None @@ -868,6 +877,31 @@ def setup_ui(self): settings_layout = QVBoxLayout(settings_group) settings_layout.setSpacing(8) + # Zarr version selection + zarr_version_layout = QHBoxLayout() + zarr_version_layout.addWidget(QLabel("Output format:")) + self.zarr_version_combo = QComboBox() + self.zarr_version_combo.addItem("Zarr v2 (napari-compatible)", 2) + self.zarr_version_combo.addItem("Zarr v3 (faster, custom viewer)", 3) + self.zarr_version_combo.setCurrentIndex(0) # Default to v2 + self.zarr_version_combo.setToolTip( + "Zarr v2: Opens in standard napari and most viewers\n" + "Zarr v3: Better performance with sharding, requires custom viewer button" + ) + zarr_version_layout.addWidget(self.zarr_version_combo) + zarr_version_layout.addStretch() + 
settings_layout.addLayout(zarr_version_layout) + + # Separate timepoints option + self.separate_timepoints_checkbox = QCheckBox("Save timepoints as separate files") + self.separate_timepoints_checkbox.setChecked(False) + self.separate_timepoints_checkbox.setToolTip( + "When enabled: Creates t0.ome.zarr, t1.ome.zarr, etc.\n" + "When disabled: All timepoints in single file (default)\n" + "Note: Only applies when dataset has multiple timepoints" + ) + settings_layout.addWidget(self.separate_timepoints_checkbox) + self.registration_checkbox = QCheckBox("Enable registration refinement") self.registration_checkbox.setChecked(False) self.registration_checkbox.toggled.connect(self.on_registration_toggled) @@ -895,14 +929,41 @@ def setup_ui(self): # Blend pixels (shown when blending enabled) self.blend_value_widget = QWidget() self.blend_value_widget.setVisible(False) - blend_value_layout = QHBoxLayout(self.blend_value_widget) + blend_value_layout = QVBoxLayout(self.blend_value_widget) blend_value_layout.setContentsMargins(20, 0, 0, 0) - blend_value_layout.addWidget(QLabel("Blend pixels:")) + blend_value_layout.setSpacing(8) + + # Blend width row + blend_width_row = QHBoxLayout() + blend_width_row.addWidget(QLabel("Blend pixels:")) self.blend_spin = QSpinBox() self.blend_spin.setRange(1, 500) self.blend_spin.setValue(50) - blend_value_layout.addWidget(self.blend_spin) - blend_value_layout.addStretch() + blend_width_row.addWidget(self.blend_spin) + blend_width_row.addStretch() + blend_value_layout.addLayout(blend_width_row) + + # Blend bias row + blend_bias_row = QHBoxLayout() + blend_bias_row.addWidget(QLabel("Blend bias:")) + self.blend_bias_slider = QSlider(Qt.Horizontal) + self.blend_bias_slider.setRange(0, 100) + self.blend_bias_slider.setValue(50) + self.blend_bias_slider.setFixedWidth(120) + self.blend_bias_slider.setToolTip( + "Controls left/right tile contribution in overlaps:\n" + "• 50%: Equal blend (default)\n" + "• >50%: Favor left/top tiles\n" + "• <50%: Favor 
right/bottom tiles" + ) + self.blend_bias_slider.valueChanged.connect(self._on_blend_bias_changed) + blend_bias_row.addWidget(self.blend_bias_slider) + self.blend_bias_label = QLabel("50%") + self.blend_bias_label.setFixedWidth(35) + blend_bias_row.addWidget(self.blend_bias_label) + blend_bias_row.addStretch() + blend_value_layout.addLayout(blend_bias_row) + settings_layout.addWidget(self.blend_value_widget) layout.addWidget(settings_group) @@ -1001,6 +1062,9 @@ def on_registration_toggled(self, checked): def on_blend_toggled(self, checked): self.blend_value_widget.setVisible(checked) + def _on_blend_bias_changed(self, value): + self.blend_bias_label.setText(f"{value}%") + def on_flatfield_toggled(self, checked): # Only show/hide flatfield options; preserve any loaded/calculated data self.flatfield_options_widget.setVisible(checked) @@ -1219,15 +1283,21 @@ def run_stitching(self): if self.blend_checkbox.isChecked(): blend_val = self.blend_spin.value() blend_pixels = (blend_val, blend_val) + blend_bias = self.blend_bias_slider.value() / 100.0 fusion_mode = "blended" else: blend_pixels = (0, 0) + blend_bias = 0.5 fusion_mode = "direct" # Get flatfield if enabled flatfield = self.flatfield if self.flatfield_checkbox.isChecked() else None darkfield = self.darkfield if self.flatfield_checkbox.isChecked() else None + # Get selected zarr version + zarr_version = self.zarr_version_combo.currentData() + separate_timepoints = self.separate_timepoints_checkbox.isChecked() + self.worker = FusionWorker( self.drop_area.file_path, self.registration_checkbox.isChecked(), @@ -1236,6 +1306,9 @@ def run_stitching(self): fusion_mode, flatfield=flatfield, darkfield=darkfield, + zarr_version=zarr_version, + separate_timepoints=separate_timepoints, + blend_bias=blend_bias, ) self.worker.progress.connect(self.log) self.worker.finished.connect(self.on_fusion_finished) @@ -1374,12 +1447,22 @@ def open_in_napari(self): for scale_dir in scale_dirs: image_path = scale_dir / "image" if 
image_path.exists(): - store = ts.open( - { - "driver": "zarr3", - "kvstore": {"driver": "file", "path": str(image_path)}, - } - ).result() + # Try Zarr v3 first, fall back to v2 + try: + store = ts.open( + { + "driver": "zarr3", + "kvstore": {"driver": "file", "path": str(image_path)}, + } + ).result() + except Exception: + # Fall back to Zarr v2 + store = ts.open( + { + "driver": "zarr", + "kvstore": {"driver": "file", "path": str(image_path)}, + } + ).result() pyramid_data.append(store) if not pyramid_data: diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index 1c49d68..b6f3323 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -84,6 +84,10 @@ class TileFusion: Channel index for registration. multiscale_downsample : str Either "stride" (default) or "block_mean" to control multiscale reduction. + zarr_version : int + Zarr format version: 2 (default, napari-compatible) or 3 (with sharding for better performance). + separate_timepoints : bool + If True, save each timepoint as a separate .ome.zarr file (default False). 
""" def __init__( @@ -110,6 +114,9 @@ def __init__( region: Optional[str] = None, flatfield: Optional[np.ndarray] = None, darkfield: Optional[np.ndarray] = None, + zarr_version: int = 2, + separate_timepoints: bool = False, + blend_bias: float = 0.5, ): self.tiff_path = Path(tiff_path) if not self.tiff_path.exists(): @@ -206,12 +213,19 @@ def __init__( self._debug = bool(debug) self.metrics_filename = metrics_filename self._blend_pixels = tuple(blend_pixels) + self._blend_bias = float(blend_bias) self.channel_to_use = channel_to_use if multiscale_downsample not in ("stride", "block_mean"): raise ValueError('multiscale_downsample must be "stride" or "block_mean".') self.multiscale_downsample = multiscale_downsample + if zarr_version not in (2, 3): + raise ValueError("zarr_version must be 2 or 3") + self.zarr_version = zarr_version + + self.separate_timepoints = separate_timepoints + self._update_profiles() self.chunk_shape = (1, 1024, 1024) self.chunk_y, self.chunk_x = self.chunk_shape[-2:] @@ -413,6 +427,24 @@ def blend_pixels(self, bp: Tuple[int, int]): self._blend_pixels = tuple(bp) self._update_profiles() + @property + def blend_bias(self) -> float: + """Blend bias from 0.0 to 1.0. 
+ + Controls the left/right contribution ratio in overlap regions: + - 0.5: symmetric blend (50/50 at overlap center) + - > 0.5: favor left/top tiles (e.g., 0.7 gives ~70% from left tile) + - < 0.5: favor right/bottom tiles (e.g., 0.3 gives ~70% from right tile) + """ + return self._blend_bias + + @blend_bias.setter + def blend_bias(self, bias: float): + if not 0.0 <= bias <= 1.0: + raise ValueError("blend_bias must be between 0.0 and 1.0.") + self._blend_bias = float(bias) + self._update_profiles() + @property def max_workers(self) -> int: """Maximum concurrent I/O workers.""" @@ -438,10 +470,11 @@ def debug(self, flag: bool): # ------------------------------------------------------------------------- def _update_profiles(self) -> None: - """Recompute 1D feather profiles from blend_pixels.""" + """Recompute 1D feather profiles from blend_pixels and blend_bias.""" by, bx = self._blend_pixels - self.y_profile = make_1d_profile(self.Y, by) - self.x_profile = make_1d_profile(self.X, bx) + bias = self._blend_bias + self.y_profile = make_1d_profile(self.Y, by, bias) + self.x_profile = make_1d_profile(self.X, bx, bias) # ------------------------------------------------------------------------- # I/O methods (delegate to format-specific loaders) @@ -881,7 +914,12 @@ def _create_fused_tensorstore(self, output_path: Union[str, Path]) -> None: self.shard_chunk = shard_chunk self.fused_ts = create_zarr_store( - out, tuple(full_shape), tuple(codec_chunk), tuple(shard_chunk), self.max_workers + out, + tuple(full_shape), + tuple(codec_chunk), + tuple(shard_chunk), + self.max_workers, + self.zarr_version, ) # ------------------------------------------------------------------------- @@ -908,8 +946,23 @@ def _fuse_tiles( else: self._fuse_tiles_full_plane(z_level=z, time_idx=t) - def _fuse_tiles_direct_plane(self, z_level: int = 0, time_idx: int = 0) -> None: - """Fuse tiles using direct placement for a single z/t plane.""" + def _fuse_tiles_direct_plane( + self, z_level: int = 0, 
time_idx: int = 0, output_time_idx: int = None + ) -> None: + """Fuse tiles using direct placement for a single z/t plane. + + Parameters + ---------- + z_level : int + Z-level to process. + time_idx : int + Input timepoint index to read from. + output_time_idx : int, optional + Output timepoint index to write to. If None, uses time_idx. + """ + if output_time_idx is None: + output_time_idx = time_idx + import psutil offsets = [ @@ -951,9 +1004,9 @@ def _fuse_tiles_direct_plane(self, z_level: int = 0, time_idx: int = 0) -> None: if show_progress: print("Writing to disk...") - self.fused_ts[time_idx : time_idx + 1, :, z_level : z_level + 1, :, :].write( - output - ).result() + self.fused_ts[ + output_time_idx : output_time_idx + 1, :, z_level : z_level + 1, :, : + ].write(output).result() del output else: if show_progress: @@ -976,13 +1029,31 @@ def _fuse_tiles_direct_plane(self, z_level: int = 0, time_idx: int = 0) -> None: tile_region = tile_all[:, :tile_h, :tile_w].astype(np.uint16) # Shape: (1, C, 1, h, w) self.fused_ts[ - time_idx : time_idx + 1, :, z_level : z_level + 1, oy:y_end, ox:x_end + output_time_idx : output_time_idx + 1, + :, + z_level : z_level + 1, + oy:y_end, + ox:x_end, ].write(tile_region[np.newaxis, :, np.newaxis, :, :]).result() gc.collect() - def _fuse_tiles_full_plane(self, z_level: int = 0, time_idx: int = 0) -> None: - """Fuse all tiles using full-image accumulator for a single z/t plane.""" + def _fuse_tiles_full_plane( + self, z_level: int = 0, time_idx: int = 0, output_time_idx: int = None + ) -> None: + """Fuse all tiles using full-image accumulator for a single z/t plane. + + Parameters + ---------- + z_level : int + Z-level to process. + time_idx : int + Input timepoint index to read from. + output_time_idx : int, optional + Output timepoint index to write to. If None, uses time_idx. 
+ """ + if output_time_idx is None: + output_time_idx = time_idx offsets = [ ( int((y - self.offset[0]) / self._pixel_size[0]), @@ -1015,7 +1086,7 @@ def _fuse_tiles_full_plane(self, z_level: int = 0, time_idx: int = 0) -> None: normalize_shard(fused_block, weight_sum) # Write to 5D output: (T, C, Z, Y, X) - self.fused_ts[time_idx, :, z_level, :pad_Y, :pad_X].write( + self.fused_ts[output_time_idx, :, z_level, :pad_Y, :pad_X].write( fused_block.astype(np.uint16) ).result() @@ -1026,9 +1097,28 @@ def _fuse_tiles_full_plane(self, z_level: int = 0, time_idx: int = 0) -> None: cp.get_default_pinned_memory_pool().free_all_blocks() def _fuse_tiles_chunked_plane( - self, z_level: int = 0, time_idx: int = 0, ram_fraction: float = 0.4 + self, + z_level: int = 0, + time_idx: int = 0, + ram_fraction: float = 0.4, + output_time_idx: int = None, ) -> None: - """Fuse tiles using memory-efficient chunked processing for a single z/t plane.""" + """Fuse tiles using memory-efficient chunked processing for a single z/t plane. + + Parameters + ---------- + z_level : int + Z-level to process. + time_idx : int + Input timepoint index to read from. + ram_fraction : float + Fraction of available RAM to use. + output_time_idx : int, optional + Output timepoint index to write to. If None, uses time_idx. 
+ """ + if output_time_idx is None: + output_time_idx = time_idx + import psutil available_ram = psutil.virtual_memory().available @@ -1045,7 +1135,9 @@ def _fuse_tiles_chunked_plane( if block_size >= max(pad_Y, pad_X): if self.n_t == 1 and self.n_z == 1: print(f"Image fits in RAM budget, using full mode") - return self._fuse_tiles_full_plane(z_level=z_level, time_idx=time_idx) + return self._fuse_tiles_full_plane( + z_level=z_level, time_idx=time_idx, output_time_idx=output_time_idx + ) show_progress = self.n_t == 1 and self.n_z == 1 if show_progress: @@ -1115,7 +1207,7 @@ def _fuse_tiles_chunked_plane( fused_block[mask] /= weight_sum[mask] # Write to 5D output: (T, C, Z, Y, X) - self.fused_ts[time_idx, :, z_level, block_y:by_end, block_x:bx_end].write( + self.fused_ts[output_time_idx, :, z_level, block_y:by_end, block_x:bx_end].write( fused_block.astype(np.uint16) ).result() @@ -1138,12 +1230,13 @@ def _create_multiscales( """Build NGFF multiscales by downsampling Y/X iteratively (not Z or T).""" inp = None for idx, factor in enumerate(factors): - out_path = omezarr_path / f"scale{idx + 1}" / "image" + out_path = omezarr_path / str(idx + 1) # Standard OME-NGFF path: "1", "2", etc. if inp is not None: del inp - prev = omezarr_path / f"scale{idx}" / "image" + prev = omezarr_path / str(idx) # Previous level: "0", "1", etc. 
+ driver = "zarr3" if self.zarr_version == 3 else "zarr" inp = ts.open( - {"driver": "zarr3", "kvstore": {"driver": "file", "path": str(prev)}} + {"driver": driver, "kvstore": {"driver": "file", "path": str(prev)}} ).result() factor_to_use = factors[idx] // factors[idx - 1] if idx > 0 else factors[0] @@ -1151,6 +1244,11 @@ def _create_multiscales( _, _, _, Y, X = inp.shape new_y, new_x = Y // factor_to_use, X // factor_to_use + # Skip this level if dimensions would be too small + if new_y == 0 or new_x == 0: + print(f"Skipping scale{idx + 1} (dimensions would be {new_y}x{new_x})") + break + chunk_y = min(1024, new_y) chunk_x = min(1024, new_x) @@ -1185,16 +1283,14 @@ def _create_multiscales( down = down.astype(slab.dtype, copy=False) self.fused_ts[:, :, :, y0 : y0 + by, x0 : x0 + bx].write(down).result() - write_scale_group_metadata(omezarr_path / f"scale{idx + 1}") - - def _generate_ngff_zarr3_json( + def _generate_ngff_metadata( self, omezarr_path: Path, resolution_multiples: Sequence[Union[int, Sequence[int]]], dataset_name: str = "image", version: str = "0.5", ) -> None: - """Write OME-NGFF v0.5 multiscales JSON for Zarr3.""" + """Write OME-NGFF v0.5 multiscales metadata.""" write_ngff_metadata( omezarr_path, self._pixel_size, @@ -1202,6 +1298,7 @@ def _generate_ngff_zarr3_json( resolution_multiples, dataset_name, version, + self.zarr_version, ) # ------------------------------------------------------------------------- @@ -1251,22 +1348,100 @@ def run(self) -> None: else: print(f"Output size: {self.padded_shape[0]} x {self.padded_shape[1]}") - scale0 = self.output_path / "scale0" / "image" - scale0.parent.mkdir(parents=True, exist_ok=True) + # Handle separate timepoints option + if self.separate_timepoints and self.n_t > 1: + self._run_separate_timepoints() + else: + self._run_single_output() + + print(f"Done! 
Output: {self.output_path}")
+
+    def _run_single_output(self) -> None:
+        """Run fusion with all timepoints in a single file."""
+        scale0 = self.output_path / "0"  # Standard OME-NGFF path
+        scale0.mkdir(parents=True, exist_ok=True)
         self._create_fused_tensorstore(output_path=scale0)
         print("Fusing tiles...")
         self._fuse_tiles()
-        write_scale_group_metadata(self.output_path / "scale0")
-
         print("Building multiscale pyramid...")
         self._create_multiscales(self.output_path, factors=self.multiscale_factors)
-        self._generate_ngff_zarr3_json(
+        self._generate_ngff_metadata(
             self.output_path, resolution_multiples=self.resolution_multiples
         )
-        print(f"Done! Output: {self.output_path}")
+    def _run_separate_timepoints(self) -> None:
+        """Run fusion with each timepoint as a separate file."""
+        # Create output folder
+        output_folder = self.output_path
+        # NOTE: Path.suffix only yields the LAST extension (".zarr"), so it can
+        # never equal ".ome.zarr"; check the full file name ending instead.
+        if output_folder.name.endswith(".ome.zarr"):
+            # Remove .ome.zarr suffix and use as folder
+            output_folder = output_folder.parent / output_folder.name[: -len(".ome.zarr")]
+        output_folder.mkdir(parents=True, exist_ok=True)
+
+        print(f"Saving {self.n_t} timepoints as separate files...")
+
+        # Store original values that get modified during processing
+        original_n_t = self.n_t
+        original_padded_shape = self.padded_shape
+        original_chunk_y = self.chunk_y
+        original_chunk_x = self.chunk_x
+
+        # Process each timepoint separately
+        for t_idx in range(original_n_t):
+            print(f"\n{'='*60}")
+            print(f"Processing timepoint {t_idx + 1}/{original_n_t}")
+            print(f"{'='*60}")
+
+            # Restore values that get modified during multiscale creation
+            self.n_t = 1
+            self.padded_shape = original_padded_shape
+            self.chunk_y = original_chunk_y
+            self.chunk_x = original_chunk_x
+
+            # Create output for this timepoint
+            t_output = output_folder / f"t{t_idx}.ome.zarr"
+            scale0 = t_output / "0"  # Standard OME-NGFF path
+            scale0.mkdir(parents=True, exist_ok=True)
+            self._create_fused_tensorstore(output_path=scale0)
+
+            # Fuse only this timepoint
+            print(f"Fusing timepoint 
{t_idx}...") + self._fuse_tiles_single_timepoint(t_idx) + + print(f"Building multiscale pyramid for timepoint {t_idx}...") + self._create_multiscales(t_output, factors=self.multiscale_factors) + self._generate_ngff_metadata(t_output, resolution_multiples=self.resolution_multiples) + + # Restore original n_t + self.n_t = original_n_t + + # Update output_path to folder + self.output_path = output_folder + + def _fuse_tiles_single_timepoint(self, time_idx: int) -> None: + """Fuse tiles for a single timepoint. + + When processing separate timepoints, this function writes to output index 0 + regardless of the input time_idx, since each output file contains only one timepoint. + + Parameters + ---------- + time_idx : int + Input timepoint index to read from the source data. + """ + for z in range(self.n_z): + if self.n_z > 1: + print(f"Fusing z-level {z + 1}/{self.n_z}...") + + # Use existing fusion methods with specific time_idx + # Always write to output index 0 since we're creating single-timepoint outputs + mode = "blended" if self._blend_pixels != (0, 0) else "direct" + if mode == "direct": + self._fuse_tiles_direct_plane(z_level=z, time_idx=time_idx, output_time_idx=0) + else: + self._fuse_tiles_chunked_plane(z_level=z, time_idx=time_idx, output_time_idx=0) def stitch_all_regions(self) -> None: """Stitch all regions in the dataset, creating separate outputs per region. 
@@ -1314,6 +1489,9 @@ def stitch_all_regions(self) -> None: channel_to_use=self.channel_to_use, multiscale_downsample=self.multiscale_downsample, region=region, + zarr_version=self.zarr_version, + separate_timepoints=self.separate_timepoints, + blend_bias=self._blend_bias, ) tf.run() diff --git a/src/tilefusion/io/zarr.py b/src/tilefusion/io/zarr.py index dfd66ab..793967b 100644 --- a/src/tilefusion/io/zarr.py +++ b/src/tilefusion/io/zarr.py @@ -171,67 +171,99 @@ def create_zarr_store( chunk_shape: Tuple[int, ...], shard_chunk: Tuple[int, ...], max_workers: int = 8, + zarr_version: int = 2, ) -> ts.TensorStore: """ - Create a Zarr v3 store with sharding codec. + Create a Zarr store with optional sharding codec. Parameters ---------- output_path : Path Path for the Zarr store. shape : tuple - Full array shape (T, C, Y, X). + Full array shape (T, C, Z, Y, X). chunk_shape : tuple - Codec chunk shape. + Codec chunk shape (for v3 sharding) or main chunk shape (for v2). shard_chunk : tuple - Shard chunk shape. + Shard chunk shape (v3 only). max_workers : int I/O concurrency limit. + zarr_version : int + Zarr format version: 2 (default, napari-compatible) or 3 (with sharding). Returns ------- store : ts.TensorStore Open TensorStore for writing. 
""" - config = { - "context": { - "file_io_concurrency": {"limit": max_workers}, - "data_copy_concurrency": {"limit": max_workers}, - }, - "driver": "zarr3", - "kvstore": {"driver": "file", "path": str(output_path)}, - "metadata": { - "shape": list(shape), - "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": list(shard_chunk)}}, - "chunk_key_encoding": {"name": "default"}, - "codecs": [ - { - "name": "sharding_indexed", - "configuration": { - "chunk_shape": list(chunk_shape), - "codecs": [ - {"name": "bytes", "configuration": {"endian": "little"}}, - { - "name": "blosc", - "configuration": { - "cname": "zstd", - "clevel": 5, - "shuffle": "bitshuffle", + if zarr_version == 3: + # Zarr v3 with sharding (better performance, but not compatible with standard napari) + config = { + "context": { + "file_io_concurrency": {"limit": max_workers}, + "data_copy_concurrency": {"limit": max_workers}, + }, + "driver": "zarr3", + "kvstore": {"driver": "file", "path": str(output_path)}, + "metadata": { + "shape": list(shape), + "chunk_grid": { + "name": "regular", + "configuration": {"chunk_shape": list(shard_chunk)}, + }, + "chunk_key_encoding": {"name": "default"}, + "codecs": [ + { + "name": "sharding_indexed", + "configuration": { + "chunk_shape": list(chunk_shape), + "codecs": [ + {"name": "bytes", "configuration": {"endian": "little"}}, + { + "name": "blosc", + "configuration": { + "cname": "zstd", + "clevel": 5, + "shuffle": "bitshuffle", + }, }, - }, - ], - "index_codecs": [ - {"name": "bytes", "configuration": {"endian": "little"}}, - {"name": "crc32c"}, - ], - "index_location": "end", - }, - } - ], - "data_type": "uint16", - "dimension_names": ["t", "c", "z", "y", "x"], - }, - } + ], + "index_codecs": [ + {"name": "bytes", "configuration": {"endian": "little"}}, + {"name": "crc32c"}, + ], + "index_location": "end", + }, + } + ], + "data_type": "uint16", + "dimension_names": ["t", "c", "z", "y", "x"], + }, + } + else: + # Zarr v2 (compatible with standard 
napari and most tools)
+        config = {
+            "context": {
+                "file_io_concurrency": {"limit": max_workers},
+                "data_copy_concurrency": {"limit": max_workers},
+            },
+            "driver": "zarr",
+            "kvstore": {"driver": "file", "path": str(output_path)},
+            "metadata": {
+                "shape": list(shape),
+                "chunks": list(chunk_shape),
+                "dtype": "<u2",
+                "compressor": {
+                    "id": "blosc",
+                    "cname": "zstd",
+                    "clevel": 5,
+                    "shuffle": 2,
+                },
+            },
+        }
+
     return ts.open(config).result()
 
 
 def write_ngff_metadata(
     omezarr_path: Path,
@@ ... @@ def write_ngff_metadata(
     dataset_name: str = "image",
-    version: str = "0.5",
+    version: str = None,
+    zarr_version: int = 2,
 ) -> None:
     """
-    Write OME-NGFF v0.5 multiscales JSON for Zarr3.
+    Write OME-NGFF multiscales metadata for Zarr.
 
     Parameters
     ----------
@@ -260,8 +293,12 @@ def write_ngff_metadata(
     dataset_name : str
         Name of the dataset node.
     version : str
-        NGFF version.
+        NGFF version. Defaults to "0.4" for Zarr v2 and "0.5" for Zarr v3.
     """
+    # Default version based on zarr_version
+    if version is None:
+        version = "0.4" if zarr_version == 2 else "0.5"
+
     axes = [
         {"name": "t", "type": "time"},
         {"name": "c", "type": "channel"},
@@ -291,7 +328,7 @@ def write_ngff_metadata(
     ]
     datasets.append(
         {
-            "path": f"scale{lvl}/{dataset_name}",
+            "path": str(lvl),  # Standard OME-NGFF path: "0", "1", "2", etc. 
"coordinateTransformations": [ {"type": "scale", "scale": scale}, {"type": "translation", "translation": translation}, @@ -301,27 +338,58 @@ def write_ngff_metadata( prev_sp = spatial mult = { + "version": version, "axes": axes, "datasets": datasets, "name": dataset_name, "@type": "ngff:Image", } - metadata = { - "attributes": {"ome": {"version": version, "multiscales": [mult]}}, - "zarr_format": 3, - "node_type": "group", - } - with open(omezarr_path / "zarr.json", "w") as f: - json.dump(metadata, f, indent=2) + if zarr_version == 3: + # Zarr v3 format + metadata = { + "attributes": {"ome": {"version": version, "multiscales": [mult]}}, + "zarr_format": 3, + "node_type": "group", + } + with open(omezarr_path / "zarr.json", "w") as f: + json.dump(metadata, f, indent=2) + else: + # Zarr v2 format + metadata = {"multiscales": [mult]} + # Write .zgroup file + with open(omezarr_path / ".zgroup", "w") as f: + json.dump({"zarr_format": 2}, f) + # Write .zattrs file with OME metadata + with open(omezarr_path / ".zattrs", "w") as f: + json.dump(metadata, f, indent=2) -def write_scale_group_metadata(scale_path: Path) -> None: - """Write zarr.json for a scale group.""" - ngff = { - "attributes": {"_ARRAY_DIMENSIONS": ["t", "c", "z", "y", "x"]}, - "zarr_format": 3, - "node_type": "group", - } + +def write_scale_group_metadata(scale_path: Path, zarr_version: int = 2) -> None: + """Write metadata for a scale group. + + Parameters + ---------- + scale_path : Path + Path to the scale group directory. + zarr_version : int + Zarr format version: 2 (default) or 3. 
+ """ scale_path.mkdir(parents=True, exist_ok=True) - with open(scale_path / "zarr.json", "w") as f: - json.dump(ngff, f, indent=2) + + if zarr_version == 3: + # Zarr v3 format + ngff = { + "attributes": {"_ARRAY_DIMENSIONS": ["t", "c", "z", "y", "x"]}, + "zarr_format": 3, + "node_type": "group", + } + with open(scale_path / "zarr.json", "w") as f: + json.dump(ngff, f, indent=2) + else: + # Zarr v2 format + ngff = {"_ARRAY_DIMENSIONS": ["t", "c", "z", "y", "x"]} + with open(scale_path / ".zgroup", "w") as f: + json.dump({"zarr_format": 2}, f) + with open(scale_path / ".zattrs", "w") as f: + json.dump(ngff, f, indent=2) diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index d0d1769..20dffed 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -50,9 +50,9 @@ def compute_ssim(arr1, arr2, win_size: int) -> float: return float(_ssim_cpu(arr1_np, arr2_np, win_size=win_size, data_range=data_range)) -def make_1d_profile(length: int, blend: int) -> np.ndarray: +def make_1d_profile(length: int, blend: int, bias: float = 0.5) -> np.ndarray: """ - Create a linear ramp profile over `blend` pixels at each end. + Create a ramp profile over `blend` pixels at each end. Parameters ---------- @@ -60,18 +60,34 @@ def make_1d_profile(length: int, blend: int) -> np.ndarray: Number of pixels. blend : int Ramp width. + bias : float + Blend bias from 0.0 to 1.0 (default 0.5). + - 0.5: symmetric linear blend (50/50 at overlap center) + - > 0.5: favor left/top tiles (e.g., 0.7 gives ~70% from left tile) + - < 0.5: favor right/bottom tiles (e.g., 0.3 gives ~70% from right tile) Returns ------- prof : (length,) float32 - Linear profile. + Blend profile. 
""" blend = min(blend, length // 2) prof = np.ones(length, dtype=np.float32) if blend > 0: - ramp = np.linspace(0, 1, blend, endpoint=False, dtype=np.float32) - prof[:blend] = ramp - prof[-blend:] = ramp[::-1] + # Linear position from 0 to 1 across the blend region + t = np.linspace(0, 1, blend, endpoint=False, dtype=np.float32) + + # Use power-law ramps for asymmetric blending + # bias=0.5 gives linear ramps (power=1), symmetric profile + # bias>0.5 makes left ramp rise faster, right ramp fall slower (favor left tiles) + # bias<0.5 makes left ramp rise slower, right ramp fall faster (favor right tiles) + p_rise = max(0.1, bias * 2) # Power for rising ramp (left edge) + p_fall = max(0.1, (1 - bias) * 2) # Power for falling ramp (right edge) + + # Left edge: rising ramp with power p_rise (0 -> ~1) + prof[:blend] = t**p_rise + # Right edge: falling ramp with power p_fall (~1 -> 0), using reversed t + prof[-blend:] = t[::-1] ** p_fall return prof diff --git a/tests/test_integration.py b/tests/test_integration.py index e2ad08e..c7ea66c 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -52,12 +52,37 @@ def _write_individual_tiffs_folder( json.dump(params, f) +def _get_scale_path(path: Path, level: int = 0): + """Get path to a scale level (supports both v2 and v3 formats).""" + # zarr v3 structure: scale0/image, scale1/image, ... + v3_path = path / f"scale{level}" / "image" + # ome-zarr v2 structure: 0/, 1/, ... 
+ v2_path = path / str(level) + + if v3_path.exists(): + return v3_path + if v2_path.exists(): + return v2_path + return None + + def _read_fused_output(path: Path): - """Read the fused output from a zarr store.""" - scale0 = path / "scale0" / "image" - store = ts.open( - {"driver": "zarr3", "kvstore": {"driver": "file", "path": str(scale0)}} - ).result() + """Read the fused output from a zarr store (supports both v2 and v3 formats).""" + scale0 = _get_scale_path(path, 0) + if scale0 is None: + raise FileNotFoundError(f"No zarr data found at {path}") + + # Detect format based on path structure + if "scale" in str(scale0): + # zarr v3 + store = ts.open( + {"driver": "zarr3", "kvstore": {"driver": "file", "path": str(scale0)}} + ).result() + else: + # ome-zarr v2 + store = ts.open( + {"driver": "zarr", "kvstore": {"driver": "file", "path": str(scale0)}} + ).result() return store.read().result() @@ -100,8 +125,8 @@ def test_two_tiles_horizontal(self, tmp_path): # Verify output exists assert output_path.exists() - assert (output_path / "scale0" / "image").exists() - assert (output_path / "scale1" / "image").exists() + assert _get_scale_path(output_path, 0) is not None + assert _get_scale_path(output_path, 1) is not None # Read fused result fused = _read_fused_output(output_path) @@ -340,10 +365,10 @@ def test_pyramid_levels_exist(self, tmp_path): tf.run() # Check all scale levels exist - assert (output_path / "scale0" / "image").exists() - assert (output_path / "scale1" / "image").exists() - assert (output_path / "scale2" / "image").exists() - assert (output_path / "scale3" / "image").exists() + assert _get_scale_path(output_path, 0) is not None + assert _get_scale_path(output_path, 1) is not None + assert _get_scale_path(output_path, 2) is not None + assert _get_scale_path(output_path, 3) is not None def test_ngff_metadata(self, tmp_path): """Test that NGFF metadata is written correctly.""" @@ -363,13 +388,20 @@ def test_ngff_metadata(self, tmp_path): ) tf.run() - # 
Check zarr.json has multiscales metadata + # Check metadata exists (zarr.json for v3, .zattrs for v2) zarr_json = output_path / "zarr.json" - assert zarr_json.exists() - - with open(zarr_json) as f: - meta = json.load(f) - - assert "attributes" in meta - assert "ome" in meta["attributes"] - assert "multiscales" in meta["attributes"]["ome"] + zattrs = output_path / ".zattrs" + + if zarr_json.exists(): + # zarr v3 format + with open(zarr_json) as f: + meta = json.load(f) + assert "attributes" in meta + assert "ome" in meta["attributes"] + assert "multiscales" in meta["attributes"]["ome"] + else: + # ome-zarr v2 format + assert zattrs.exists() + with open(zattrs) as f: + meta = json.load(f) + assert "multiscales" in meta