from __future__ import annotations
import functools
import logging
import time
from collections import OrderedDict
from pathlib import Path
import h5py
import numpy as np
from blissoda.tomo.utils import apply_labels
from blissoda.tomo.utils import compute_axes
from ..bliss_globals import current_session
from ..flint.plotter import BasePlotter
# Flint's ImageView client is optional at import time: when it cannot be
# imported, ``unavailable_class`` builds a stand-in from the ImportError
# (presumably raising it on use — confirm in import_utils).
try:
    from flint.viewers.custom_image.client import ImageView
except ImportError as ex:
    from ..import_utils import unavailable_class

    ImageView = unavailable_class(ex)

# Module logger, used by the ``_background_task`` decorator and the plotter class.
logger = logging.getLogger(__name__)
def _background_task(method):
    """Wrap *method* so that any exception it raises is logged, then re-raised.

    The wrapper forwards all positional and keyword arguments unchanged and
    returns the wrapped callable's result. On failure, the error is written to
    the module logger with its traceback and the original exception
    propagates to the caller untouched.
    """

    @functools.wraps(method)
    def _logged_call(*args, **kwargs):
        try:
            outcome = method(*args, **kwargs)
        except Exception as failure:
            logger.error(
                "Online tomo plotter task failed (%s)", failure, exc_info=True
            )
            raise
        return outcome

    return _logged_call
class OnlineTomoAccumulatedPlotter(BasePlotter):
    """
    Plotter for online tomography reconstruction that accumulates partial slices.

    Monitors the output directory for new reconstructed slice files
    (``*.h5`` files containing a ``reconstructed_slice`` dataset) and sums
    them, so the plot shows the evolving reconstruction.
    """

    TITLE = "Accumulated Reconstructed Slice"

    def __init__(self, history: int = 1, retry_period: int = 3) -> None:
        """
        Initialize the plotter.

        :param history: Number of plots to keep (should be 1 for current scan only).
        :param retry_period: Polling interval in seconds to check for new batch files.
        """
        super().__init__(max_plots=history)
        self.retry_period = retry_period
        self._cache: OrderedDict[str, np.ndarray] = OrderedDict()
        self._current_plot_widget = None
        self._accumulated_slice = None
        self._batch_count = 0
        self._batch_size = None
        # Batch files already summed into ``_accumulated_slice``. Tracking by
        # path (not by position in a sorted listing) lets a file that could
        # not be read yet be retried later, and guarantees no file is added
        # to the sum twice.
        self._processed_files: set[Path] = set()

    def handle_workflow_result(
        self,
        future,
        output_path: str,
        slice_index: int | str = "middle",
        batch_size: int = 100,
    ) -> None:
        """
        Monitor workflow execution and update plot as new batches are saved.

        The monitoring itself is handed to ``self._spawn`` (provided by
        ``BasePlotter``; presumably runs it in the background).

        :param future: The workflow job future
        :param output_path: Directory where batch files are being saved
        :param slice_index: The slice index being reconstructed (for display only)
        :param batch_size: Number of projections per batch
        """
        self._spawn(
            self._monitor_reconstruction,
            future,
            output_path,
            slice_index,
            batch_size,
        )

    @_background_task
    def _monitor_reconstruction(
        self,
        future,
        output_path: str,
        slice_index: int | str,
        batch_size: int,
    ) -> None:
        """
        Background task to monitor reconstruction progress.

        Polls the output directory every ``retry_period`` seconds until the
        future is done, refreshing the plot whenever new batches have been
        accumulated, then performs a final pass.

        :param future: The workflow job future
        :param output_path: Directory where batch files are being saved
        :param slice_index: The slice index being reconstructed (for display only)
        :param batch_size: Number of projections per batch
        """
        output_dir = Path(output_path)
        # Reset state for new reconstruction
        self._accumulated_slice = None
        self._batch_count = 0
        self._batch_size = batch_size
        self._processed_files = set()
        logger.info("Starting monitoring of reconstruction in %s", output_dir)
        last_processed_count = 0
        # Monitor while workflow is running
        while not future.done():
            processed_count = self._poll_for_batches(output_dir, last_processed_count)
            # Compare against the count *before* polling: only then does a
            # change mean new batches were added and the plot needs a refresh.
            if processed_count != last_processed_count:
                self._update_plot(slice_index)
                last_processed_count = processed_count
            time.sleep(self.retry_period)
        # Final update after workflow completes
        try:
            self._finalize_monitoring(output_dir, last_processed_count, slice_index)
        except Exception as e:
            logger.error("Error in final batch processing: %s", e, exc_info=True)

    def _poll_for_batches(self, output_dir: Path, start_count: int) -> int:
        """
        Poll the output directory for new batch files and process them.

        Errors are logged at debug level and swallowed so monitoring keeps
        going; a file whose processing failed is not marked as processed and
        is retried on the next poll.

        :param output_dir: Directory containing the batch slice files
        :param start_count: Number of batches already processed
        :return: Total number of batches processed so far
        """
        try:
            self._process_new_batches(output_dir, start_count)
        except Exception as e:
            logger.debug("Error processing batches: %s", e)
        # Always report the real count: batches accumulated before an error
        # are already part of the sum and must not be re-read.
        return self._batch_count

    def _finalize_monitoring(
        self, output_dir: Path, last_processed_count: int, slice_index
    ):
        """
        Final processing after workflow completion.

        :param output_dir: Directory containing the batch slice files
        :param last_processed_count: Number of batches already processed
        :param slice_index: The slice index being reconstructed
        """
        self._process_new_batches(output_dir, last_processed_count)
        self._update_plot(slice_index)
        projections_done = self._batch_count * (self._batch_size or 0)
        logger.info(
            "Reconstruction monitoring completed: %s projections processed.",
            projections_done,
        )

    def _process_new_batches(self, output_dir: Path, start_count: int = 0) -> int:
        """
        Accumulate every batch slice file that has not been processed yet.

        Files are tracked by path, so a file that lacks the
        ``reconstructed_slice`` dataset or cannot be opened is skipped now
        and retried on a later call, and no file is ever summed twice.

        :param output_dir: Directory containing the batch slice files
        :param start_count: Number of batches already processed. Unused,
            kept for backward compatibility (files are tracked by path).
        :return: Number of new batches processed
        """
        if not output_dir.exists():
            return 0
        new_batches = 0
        for batch_file in sorted(output_dir.glob("*.h5")):
            if batch_file in self._processed_files:
                continue
            slice_data = self._read_batch(batch_file)
            if slice_data is None:
                # Not readable yet; leave it unmarked so it is retried.
                continue
            # Accumulate by addition (float32 accumulator, copied from the data)
            if self._accumulated_slice is None:
                self._accumulated_slice = slice_data.astype(np.float32)
            else:
                self._accumulated_slice += slice_data
            self._processed_files.add(batch_file)
            self._batch_count += 1
            new_batches += 1
            logger.debug("Processed batch file: %s", batch_file.name)
        return new_batches

    @staticmethod
    def _read_batch(batch_file: Path) -> np.ndarray | None:
        """
        Read the ``reconstructed_slice`` dataset of one batch file.

        :param batch_file: HDF5 file to read
        :return: The dataset contents, or None if the dataset is missing or
            the file cannot be read (OSError/KeyError).
        """
        try:
            with h5py.File(batch_file, "r") as f:
                if "reconstructed_slice" not in f:
                    logger.debug("Missing dataset in %s", batch_file)
                    return None
                return f["reconstructed_slice"][()]
        except (OSError, KeyError) as e:
            logger.debug("Could not read %s: %s", batch_file, e)
            return None

    def _update_plot(self, slice_index: int | str) -> None:
        """
        Update the Flint plot with the accumulated slice.

        Does nothing until at least one batch has been accumulated.

        :param slice_index: The slice index being reconstructed
        """
        if self._accumulated_slice is None:
            return
        widget = self._get_plot(self.TITLE, ImageView)
        # Set title with progress information
        self._set_title_with_progress(widget, slice_index)
        # Compute physical axes from tomo config. Only the first step of each
        # axis is used as pixel size, i.e. the axes are treated as evenly spaced.
        x_axis, y_axis = compute_axes(self._accumulated_slice)
        origin = (float(x_axis[0]), float(y_axis[0]))
        dx = float(x_axis[1] - x_axis[0]) if len(x_axis) > 1 else 1.0
        dy = float(y_axis[1] - y_axis[0]) if len(y_axis) > 1 else 1.0
        widget.set_data(self._accumulated_slice, origin=origin, scale=(dx, dy))
        # Apply axis labels
        apply_labels(widget)
        # Cache the current image, then drop old tasks/images beyond history
        self._cache[self.TITLE] = self._accumulated_slice
        self.purge_tasks()
        self._purge()

    def _set_title_with_progress(
        self, widget: ImageView, slice_index: int | str
    ) -> None:
        """
        Set the plot window title with progress information.

        :param widget: The plot widget
        :param slice_index: The slice index being reconstructed
        """
        base_title = current_session.scan_saving.data_filename
        # Batch size is unknown (None) until a monitoring run has started.
        nb_projections = self._batch_count * (self._batch_size or 0)
        widget.title = (
            f"{base_title} - Slice {slice_index} "
            f"({nb_projections} projections processed)"
        )

    def clear(self) -> None:
        """
        Reset the accumulation state (accumulated slice, batch count and
        processed-file tracking).

        NOTE(review): the image currently displayed in Flint is not touched here.
        """
        self._accumulated_slice = None
        self._batch_count = 0
        self._processed_files = set()
        logger.debug("Plotter state cleared")

    def _purge(self) -> None:
        """
        Remove oldest images beyond history.

        ``_max_plots`` is presumably set by ``BasePlotter`` from ``max_plots``.
        """
        while len(self._cache) > self._max_plots:
            self._cache.popitem(last=False)