Source code for blissoda.id26.scans_aggregator

from __future__ import annotations

import json
import logging
import os
from contextlib import contextmanager
from typing import Any

from ..import_utils import unavailable_module

try:
    import gevent
except ImportError as ex:
    gevent = unavailable_module(ex)
import numpy as np

from blissoda.resources import resource_filename

from ..bliss_globals import current_session  # type: ignore
from ..ewoks_utils import get_future
from ..ewoks_utils import submit
from ..flint.access import WithFlintAccess
from ..persistent.parameters import ParameterInfo
from ..processor import BaseProcessor
from ..processor import BlissScanType  # type: ignore
from .plots import AggregatedScansPlot

logger = logging.getLogger(__name__)


class ScansAggregator(
    BaseProcessor,
    parameters=[
        ParameterInfo("workflow", category="workflows"),
        ParameterInfo("data_mappings", category="workflows"),
        ParameterInfo("filename", category="workflows"),
        ParameterInfo("start_scan_id", category="workflows"),
        ParameterInfo("stop_scan_id", category="workflows"),
        ParameterInfo("aggregation_mode", category="workflows"),
        ParameterInfo("use_daxs", category="workflows"),
    ],
):
    """Processor that submits the scans-aggregation workflow.

    The workflow is submitted with the parameters declared above as inputs and
    its result is handed to an :class:`AggregatedScansHandler` for plotting.
    """

    # Values applied for every parameter the caller did not provide; also the
    # values restored by ``reset`` (except ``data_mappings``).
    DEFAULT_PARAMETERS = {
        "workflow": resource_filename("id26", "scans_aggregator.json"),
        "data_mappings": {
            "x": ".1/measurement/elapsed_time",
            "signal": ".1/measurement/sim_gaussian_1",
        },
        "filename": None,
        "start_scan_id": None,
        "stop_scan_id": None,
        "aggregation_mode": "fraction of sums",
        "use_daxs": True,
        "trigger_at": "END",
        "_enabled": False,
        "_clear_plot": True,
    }

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        defaults: dict[str, Any] | None = None,
        **deprecated_defaults: dict[str, Any],
    ) -> None:
        """Create the processor and its plot handler.

        :param config: forwarded unchanged to ``BaseProcessor.__init__``.
        :param defaults: default parameter values; entries missing from it are
            filled in from ``DEFAULT_PARAMETERS``.
        :param deprecated_defaults: merged into ``defaults`` by
            ``_merge_defaults``.
        """
        defaults = self._merge_defaults(deprecated_defaults, defaults)
        for parameter, value in self.DEFAULT_PARAMETERS.items():
            defaults.setdefault(parameter, value)
        super().__init__(config=config, defaults=defaults)
        self.handler = AggregatedScansHandler()

    def reset(self):
        """Restore the default parameters and clear the plot.

        ``data_mappings`` is deliberately left untouched, and ``disable()`` is
        only called when the processor is currently enabled.
        """
        for parameter, value in self.DEFAULT_PARAMETERS.items():
            if parameter == "_enabled" and self._enabled:
                self.disable()  # Only disable if currently enabled.
            elif parameter == "data_mappings":
                continue  # Do not reset data mappings
            else:
                setattr(self, parameter, value)
        self.handler.clear()

    def _get_current_scan_id(self) -> None | int:
        """Return the ``scan_nb`` of the last scan of the current session.

        :raises ValueError: if there is no current session, the session has no
            scans, or the scan number is not an integer.
        """
        if current_session is None:
            raise ValueError("No current session available to get the scan from.")
        try:
            scan = current_session.scans[-1]
        except (AttributeError, IndexError) as err:
            raise ValueError("No scans found in the current session.") from err
        scan_id = scan.scan_info["scan_nb"]
        if not isinstance(scan_id, int):
            raise ValueError("The id of the current scan is not a valid integer.")
        return scan_id

    def _resolve_scan_id(self, name: str, value: int) -> int:
        """Turn a possibly negative scan id into an absolute, positive one.

        A negative ``value`` is relative to the current scan: ``-1`` is the
        current scan, ``-2`` the one before it, and so on.

        :param name: parameter name, only used in the error messages.
        :param value: absolute (positive) or relative (negative) scan id.
        :raises ValueError: if the current scan id is needed but unavailable,
            or if the resolved id is not strictly positive.
        """
        if value < 0:
            current_scan_id = self._get_current_scan_id()
            if not isinstance(current_scan_id, int):
                raise ValueError(
                    "The id of the current scan must be a valid integer before"
                    f" using negative {name}."
                )
            value = value + current_scan_id + 1
        if value == 0:
            raise ValueError(f"The {name} cannot be zero.")
        if value < 0:
            # A still-negative id would later be re-interpreted as relative.
            raise ValueError(
                f"The {name} points before the first scan (resolved to {value})."
            )
        return value

    def _get_start_scan_id(self) -> None | int:
        """Return the stored ``start_scan_id`` parameter."""
        return self._get_parameter("start_scan_id")  # type: ignore

    def _set_start_scan_id(self, value: None | int) -> None:
        """Store ``start_scan_id``; ``None`` is stored as-is, ints are resolved."""
        if value is not None:
            value = self._resolve_scan_id("start_scan_id", value)
        self._set_parameter("start_scan_id", value)

    def _get_stop_scan_id(self) -> None | int:
        """Return the stored ``stop_scan_id`` parameter."""
        return self._get_parameter("stop_scan_id")  # type: ignore

    def _set_stop_scan_id(self, value: None | int) -> None:
        """Store ``stop_scan_id``; ``None`` is stored as-is, ints are resolved."""
        if value is not None:
            value = self._resolve_scan_id("stop_scan_id", value)
        self._set_parameter("stop_scan_id", value)

    def _get_workflow(self) -> dict:
        """Load the workflow definition from the JSON file ``self.workflow``."""
        with open(self.workflow, encoding="utf-8") as wf:
            return json.load(wf)

    def _get_workflow_inputs(self, scan) -> list:
        """Build the list of ``{"name": ..., "value": ...}`` workflow inputs.

        When ``scan`` is given, the stop id is moved to the current scan and the
        start id is initialised to it if it was not set yet.
        """
        if scan is not None:
            current_scan_id = self._get_current_scan_id()
            if self._get_start_scan_id() is None:
                self._set_start_scan_id(current_scan_id)
            self._set_stop_scan_id(current_scan_id)

        inputs = [
            {"name": "data_mappings", "value": dict(self.data_mappings)},
            {"name": "filename", "value": self._get_filename(scan)},
        ]
        inputs.extend(
            {"name": name, "value": self._get_parameter(name)}
            for name in (
                "start_scan_id",
                "stop_scan_id",
                "aggregation_mode",
                "use_daxs",
            )
        )
        return inputs

    def _get_filename(self, scan=None) -> str:
        """Return the data file name to aggregate scans from.

        Order of precedence: the ``filename`` parameter, then the ``filename``
        entry of ``scan.scan_info`` (which is then stored in the parameter),
        then the ``scan_saving`` of the current session.

        :raises ValueError: if ``scan`` has no filename, or no scan is given and
            there is no current session.
        """
        if self.filename is not None:
            return self.filename
        if scan is not None:
            filename = scan.scan_info.get("filename")
            if filename is None:
                raise ValueError("Scan has no filename.")
            self.filename = filename
            return filename
        if current_session is None:
            raise ValueError("No current session available to get the filename from.")
        return current_session.scan_saving.filename

    def _get_processed_data_path(self) -> str:
        """Return the path of the text file the aggregated data is saved to.

        The file name is built from the stem of ``self.filename`` and the
        stored start/stop scan ids. Without a current session the path has no
        directory component.

        :raises ValueError: if the ``filename`` parameter is not set.
        """
        if self.filename is None:
            raise ValueError("Failed to get processed data path.")
        stem = os.path.splitext(os.path.basename(self.filename))[0]
        if current_session is not None:
            scan_saving = current_session.scan_saving
            root = os.path.join(
                scan_saving.base_path,
                scan_saving.proposal_dirname,
                scan_saving.beamline,
                scan_saving.proposal_session_name,
                "PROCESSED_DATA",
            )
        else:
            root = ""
        start_scan_id = self._get_start_scan_id()
        stop_scan_id = self._get_stop_scan_id()
        return os.path.join(
            root, f"{stem}_aggregated_scans_{start_scan_id}_to_{stop_scan_id}.dat"
        )

    def _trigger_workflow_on_new_scan(self, scan: BlissScanType | None) -> dict | None:
        """Submit the workflow and let the handler recover and plot its result.

        :param scan: the scan that triggered the call, or ``None`` when called
            manually through ``__call__``.
        :returns: always ``None``.
        """
        workflow = self._get_workflow()
        workflow_inputs = self._get_workflow_inputs(scan)
        future = submit(
            args=(workflow,),
            kwargs={"inputs": workflow_inputs, "outputs": [{"all": False}]},
        )
        clear_plot = self._get_parameter("_clear_plot")
        # Consolidate the workflow inputs into a single dictionary.
        inputs_by_name = {
            workflow_input["name"]: workflow_input["value"]
            for workflow_input in workflow_inputs
        }
        self.handler.handle_workflow_result(
            future.uuid, clear_plot=clear_plot, workflow_inputs=inputs_by_name
        )
        return None

    @contextmanager
    def _temporary_scan_ids(
        self,
        start_scan_id: int,
        stop_scan_id: int,
    ):
        """Temporarily override scan IDs, restoring previous values on exit."""
        prev_start_scan_id = self._get_start_scan_id()
        prev_stop_scan_id = self._get_stop_scan_id()
        try:
            # Inside the ``try`` so a failure on the second setter still
            # restores the first one.
            self._set_start_scan_id(start_scan_id)
            self._set_stop_scan_id(stop_scan_id)
            yield
        finally:
            # Restore the stored values verbatim (no re-validation).
            self._set_parameter("start_scan_id", prev_start_scan_id)
            self._set_parameter("stop_scan_id", prev_stop_scan_id)

    def __call__(
        self,
        start_scan_id: int,
        stop_scan_id: int,
        clear_plot: bool = True,
        save_txt: bool = False,
    ) -> None:
        """Aggregate the scans from ``start_scan_id`` to ``stop_scan_id``.

        :param start_scan_id: first scan; negative values are relative to the
            current scan (``-1`` is the current scan).
        :param stop_scan_id: last scan; same convention as ``start_scan_id``.
        :param clear_plot: stored in the ``_clear_plot`` parameter and passed
            to the plot handler.
        :param save_txt: also write the aggregated ``x``/``signal`` columns to
            the path given by ``_get_processed_data_path``.
        :raises ValueError: if a scan id cannot be resolved to a positive id.
        """
        # Resolve relative ids first so that e.g. (1, -1) compares correctly.
        start_scan_id = self._resolve_scan_id("start_scan_id", start_scan_id)
        stop_scan_id = self._resolve_scan_id("stop_scan_id", stop_scan_id)
        if start_scan_id > stop_scan_id:
            logger.warning(
                f"Invalid scan indices; the start_scan_id ({start_scan_id}) "
                f"should be smaller than the stop_scan_id ({stop_scan_id})."
            )
            return

        self._set_parameter("_clear_plot", clear_plot)

        with self._temporary_scan_ids(start_scan_id, stop_scan_id):
            try:
                self._trigger_workflow_on_new_scan(None)
            except Exception as e:
                logger.warning(e, exc_info=True)

            if not save_txt:
                return

            # NOTE(review): handle_workflow_result() only spawns a greenlet, so
            # the result read here may still be the one of a previous run (or
            # None) - confirm whether this should wait for the workflow.
            result = self.handler.get_result()
            if result is None:
                logger.warning("No aggregated data available to save.")
                return

            x = result["measurement"]["x"]
            signal = result["measurement"]["signal"]

            filename = self._get_processed_data_path()
            directory = os.path.dirname(filename)
            if directory:
                # os.makedirs("") raises; a bare file name needs no directory.
                os.makedirs(directory, exist_ok=True)
            np.savetxt(filename, np.column_stack((x, signal)), header="x\tsignal")
            print(f"Data saved to {filename}.")
class AggregatedScansHandler(WithFlintAccess):
    """Recover the result of an aggregation workflow and plot it."""

    def __init__(self) -> None:
        super().__init__()
        # Inputs of the last successfully recovered workflow (name -> value).
        self._workflow_inputs = None
        # Result of the last workflow; None before any run and after a failure.
        self._result = None

    def clear(self):
        """Clear the plot and reset its axis labels and title."""
        plot = self._get_plot()
        plot.submit("clear")
        plot.xlabel = "X"
        plot.ylabel = "Y"
        plot.title = ""

    def handle_workflow_result(self, *args, **kwargs) -> None:
        """Handle workflow result in a background task to recover data.

        All arguments are forwarded to ``_handle_workflow_result``, which is
        run in a spawned greenlet; this call does not wait for it.
        """
        gevent.spawn(self._handle_workflow_result, *args, **kwargs)

    def _handle_workflow_result(
        self, task_id, timeout: int = 60, clear_plot: bool = True, workflow_inputs=None
    ) -> None:
        """Fetch the result of task ``task_id`` and plot it.

        :param task_id: identifier passed to ``get_future``.
        :param timeout: passed as ``timeout`` to the future's ``get``.
        :param clear_plot: forwarded to ``plot_data``.
        :param workflow_inputs: workflow inputs (name -> value), kept for the
            plot title, labels and legend.

        Any exception (from fetching or plotting) is logged and resets the
        stored result to ``None``.
        """
        logger.info(f"Recovering processed data for task {task_id}")
        try:
            result = get_future(task_id).get(timeout=timeout)
            self._result = result
            # Keep the workflow inputs next to the result; plot_data needs them.
            self._workflow_inputs = workflow_inputs
            self.plot_data(clear_plot=clear_plot)
        except Exception as e:
            if not str(e):
                logger.error("The scans aggregator workflow has failed.")
            else:
                logger.error(e)
            self._result = None

    def _get_plot(self, plot_id="Aggregated Scans", plot_cls=AggregatedScansPlot):
        """Return the plot via the parent class, with this handler's defaults."""
        return super()._get_plot(plot_id, plot_cls)

    def _get_filename(self):
        """Return the ``filename`` workflow input, or ``""`` when unavailable."""
        if self._workflow_inputs is None:
            return ""
        return self._workflow_inputs.get("filename", "")

    def _get_axes_labels(self) -> dict[str, str]:
        """Derive the axis labels from the ``data_mappings`` workflow input.

        Each label is the last ``/``-separated component of the mapped path.
        Falls back to ``"X"`` / ``"Y"`` when no mapping is available.
        """
        axes_labels = {"x": "X", "y": "Y"}
        if not self._workflow_inputs:
            return axes_labels
        data = self._workflow_inputs.get("data_mappings", {})
        if not data:
            return axes_labels
        if "x" in data:
            axes_labels["x"] = data["x"].split("/")[-1]
        if "signal" in data:
            axes_labels["y"] = data["signal"].split("/")[-1]
        if "monitor" in data:
            axes_labels["y"] += " / " + data["monitor"].split("/")[-1]
        return axes_labels

    def get_result(self) -> dict[str, Any] | None:
        """Return the last stored workflow result, or ``None``."""
        return self._result

    def plot_data(self, clear_plot: bool = True):
        """Plot every scan of the stored result plus the aggregated curve.

        Does nothing when there is no result or no workflow inputs. The scan
        with the largest id is drawn in black and tagged "(last)", the others
        in gray; the aggregated measurement is drawn in red.

        :param clear_plot: clear the plot before adding the curves.
        """
        if not self._result or self._workflow_inputs is None:
            return

        plot = self._get_plot()
        if clear_plot:
            plot.submit("clear")

        scans = self._result["scans"]
        # Computed once instead of once per scan; None when there are no scans.
        last_scan_id = max(scans, default=None)
        for scan_id in scans:
            color = "gray"
            legend = f"Scan {scan_id}"
            if scan_id == last_scan_id:
                color = "black"
                legend = legend + " (last)"
            plot.add_curve(
                scans[scan_id]["x"],
                scans[scan_id]["signal"],
                legend=legend,
                linestyle="-",
                color=color,
            )

        measurement = self._result["measurement"]
        plot.add_curve(
            measurement["x"],
            measurement["signal"],
            legend=self._workflow_inputs["aggregation_mode"].capitalize(),
            linestyle="-",
            color="red",
            linewidth=2,
        )

        filename = self._get_filename()
        plot.title = f"{filename}"

        axes_labels = self._get_axes_labels()
        plot.xlabel = axes_labels["x"]
        plot.ylabel = axes_labels["y"]