Source code for blissoda.id26.rixs_plotter

from __future__ import annotations

import json
import logging
import os
from contextlib import contextmanager
from typing import Any

from ..import_utils import unavailable_module

try:
    import gevent
except ImportError as ex:
    gevent = unavailable_module(ex)
from blissoda.resources import resource_filename

from ..bliss_globals import current_session  # type: ignore
from ..ewoks_utils import get_future
from ..ewoks_utils import submit
from ..flint.access import WithFlintAccess
from ..persistent.parameters import ParameterInfo
from ..processor import BaseProcessor
from ..processor import BlissScanType  # type: ignore
from ..utils.directories import get_dataset_processed_dir
from .plots import RixsPlot

logger = logging.getLogger(__name__)


class RixsPlotter(
    BaseProcessor,
    parameters=[
        ParameterInfo("workflow", category="workflows"),
        ParameterInfo("data_mappings", category="workflows"),
        ParameterInfo("filename", category="workflows"),
        ParameterInfo("start_scan_id", category="workflows"),
        ParameterInfo("stop_scan_id", category="workflows"),
        ParameterInfo("conc_corr_scan_id", category="workflows"),
        ParameterInfo("y_axis", category="workflows"),
    ],
):
    """Online plotter for RIXS planes displayed in a Flint widget.

    Typically instantiated in the Bliss session setup script. Use ``enable()``
    / ``disable()`` to start/stop automatic plotting on each new scan, or call
    the instance directly with a scan range::

        rixs_plotter(start_scan_id, stop_scan_id)

    Negative start/stop scan ids are relative to the most recent scan of the
    current session: ``-1`` is the most recent scan, ``-2`` the one before it,
    and so on.

    Call ``reset()`` before ``enable()`` to start a new map from scratch.

    The ``y_axis`` attribute can be set to either:

    - ``"emission energy"`` (default): plots emission energy on Y.
    - ``"energy transfer"``: plots (incident energy - emission energy) on Y.
    """

    DEFAULT_PARAMETERS = {
        "workflow": resource_filename("id26", "rixs_plotter.json"),
        "data_mappings": {
            "x": ".1/measurement/elapsed_time",
            "y": ".1/instrument/positioners/xes_en",
            "signal": ".1/measurement/sim_gaussian_1",
        },
        "filename": None,
        "start_scan_id": None,
        "stop_scan_id": None,
        "conc_corr_scan_id": None,
        "y_axis": "emission energy",
        "trigger_at": "END",
        "_enabled": False,
        "_clear_plot": True,
    }

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        defaults: dict[str, Any] | None = None,
        **deprecated_defaults: dict[str, Any],
    ) -> None:
        """Create the plotter and its Flint result handler.

        :param config: configuration forwarded to ``BaseProcessor``.
        :param defaults: default parameter values; any parameter missing here
            falls back to ``DEFAULT_PARAMETERS``.
        :param deprecated_defaults: legacy keyword defaults, merged into
            ``defaults`` by ``_merge_defaults``.
        """
        defaults = self._merge_defaults(deprecated_defaults, defaults)
        for parameter, value in self.DEFAULT_PARAMETERS.items():
            # Hand out a copy of dict defaults so instances never share (and
            # possibly mutate) the class-level DEFAULT_PARAMETERS dictionaries.
            if isinstance(value, dict):
                value = dict(value)
            defaults.setdefault(parameter, value)
        super().__init__(config=config, defaults=defaults)
        self.handler = RixsHandler()

    def reset(self):
        """Restore default parameters (except ``data_mappings``) and clear the plot."""
        for parameter, value in self.DEFAULT_PARAMETERS.items():
            if parameter == "_enabled" and self._enabled:
                self.disable()  # Only disable if currently enabled.
            elif parameter == "data_mappings":
                continue  # Do not reset data mappings.
            else:
                setattr(self, parameter, value)
        self.handler.clear()

    def _get_current_scan_id(self) -> int:
        """Return the ``scan_nb`` of the most recent scan of the current session.

        :raises ValueError: when there is no session, no scan, or the scan
            number is missing or not an integer.
        """
        if current_session is None:
            raise ValueError("No current session available to get the scan from.")
        try:
            scan = current_session.scans[-1]
        except (AttributeError, IndexError) as ex:
            raise ValueError("No scans found in the current session.") from ex
        scan_id = scan.scan_info.get("scan_nb")
        if not isinstance(scan_id, int):
            raise ValueError("The id of the current scan is not a valid integer.")
        return scan_id

    def _resolve_scan_id(self, value: int, name: str) -> int:
        """Turn a possibly relative (negative) scan id into an absolute one.

        :param value: absolute scan id (> 0) or relative id (< 0, where -1 is
            the most recent scan).
        :param name: parameter name, used in error messages.
        :raises ValueError: when the id is zero or resolves to a scan before
            the first one.
        """
        if value < 0:
            value = value + self._get_current_scan_id() + 1
        if value == 0:
            raise ValueError(f"The {name} cannot be zero.")
        if value < 0:
            # Relative ids that point before the first scan used to be stored
            # as negative absolute ids; reject them explicitly instead.
            raise ValueError(f"The {name} refers to a scan before the first scan.")
        return value

    def _get_start_scan_id(self) -> None | int:
        return self._get_parameter("start_scan_id")  # type: ignore

    def _set_start_scan_id(self, value: None | int) -> None:
        if value is not None:
            value = self._resolve_scan_id(value, "start_scan_id")
        self._set_parameter("start_scan_id", value)

    def _get_stop_scan_id(self) -> None | int:
        return self._get_parameter("stop_scan_id")  # type: ignore

    def _set_stop_scan_id(self, value: None | int) -> None:
        if value is not None:
            value = self._resolve_scan_id(value, "stop_scan_id")
        self._set_parameter("stop_scan_id", value)

    def _get_conc_corr_scan_id(self) -> None | int:
        return self._get_parameter("conc_corr_scan_id")  # type: ignore

    def _get_y_axis(self) -> str:
        return self._get_parameter("y_axis")  # type: ignore

    def _set_conc_corr_scan_id(self, value: None | int) -> None:
        # Unlike start/stop ids, relative (negative) ids are not supported here.
        if value is None:
            self._set_parameter("conc_corr_scan_id", None)
            return
        if value < 0:
            raise ValueError(
                "The id of the concentration correction scan cannot be negative."
            )
        if value == 0:
            raise ValueError(
                "The id of the concentration correction scan cannot be zero."
            )
        self._set_parameter("conc_corr_scan_id", value)

    def _get_workflow(self) -> dict:
        """Load the workflow description from the JSON file ``self.workflow``."""
        with open(self.workflow, encoding="utf-8") as wf:
            return json.load(wf)

    def _get_workflow_inputs(self, scan) -> list:
        """Build the list of ``{"name": ..., "value": ...}`` workflow inputs.

        When triggered by a scan, the scan range is extended to the most recent
        scan (and started there if no start id has been set yet).
        """
        if scan is not None:
            current_scan_id = self._get_current_scan_id()
            if self._get_start_scan_id() is None:
                self._set_start_scan_id(current_scan_id)
            self._set_stop_scan_id(current_scan_id)
        values = {
            "data_mappings": dict(self.data_mappings),
            "filename": self._get_filename(scan),
            "start_scan_id": self._get_start_scan_id(),
            "stop_scan_id": self._get_stop_scan_id(),
            "conc_corr_scan_id": self._get_conc_corr_scan_id(),
            "y_axis": self._get_y_axis(),
        }
        return [{"name": name, "value": value} for name, value in values.items()]

    def _get_filename(self, scan=None) -> str:
        """Return the data filename to process.

        Priority: the ``filename`` parameter, then the scan's ``filename``
        (remembered in the parameter for later calls; ``reset()`` clears it),
        then the session's current ``scan_saving.filename``.
        """
        if self.filename is not None:
            return self.filename
        if scan is not None:
            filename = scan.scan_info.get("filename")
            if filename is None:
                raise ValueError("Scan has no filename.")
            self.filename = filename
            return filename
        if current_session is None:
            raise ValueError("No current session available to get the filename from.")
        return current_session.scan_saving.filename

    def _trigger_workflow_on_new_scan(self, scan: BlissScanType | None) -> None:
        """Submit the RIXS workflow and pass its task id to the plot handler.

        :param scan: the scan that triggered processing, or ``None`` when
            called from ``__call__`` with explicit scan ids.
        """
        workflow = self._get_workflow()
        workflow_inputs = self._get_workflow_inputs(scan)
        # Consolidate the workflow inputs into a single dictionary.
        inputs_by_name = {item["name"]: item["value"] for item in workflow_inputs}
        start_scan_id = inputs_by_name["start_scan_id"]
        stop_scan_id = inputs_by_name["stop_scan_id"]
        out_dirname = get_dataset_processed_dir(inputs_by_name["filename"])

        kwargs = {
            "inputs": workflow_inputs,
            "outputs": [{"all": False}],
        }
        # Only request a convert destination when the processed dir exists.
        if os.path.exists(out_dirname):
            kwargs["convert_destination"] = os.path.join(
                out_dirname,
                "workflows",
                f"rixs_plotter_{start_scan_id}_to_{stop_scan_id}.json",
            )
        future = submit(args=(workflow,), kwargs=kwargs)

        self.handler.handle_workflow_result(
            future.task_id,
            clear_plot=self._get_parameter("_clear_plot"),
            workflow_inputs=inputs_by_name,
        )

    @contextmanager
    def _temporary_scan_ids(
        self,
        start_scan_id: int,
        stop_scan_id: int,
        conc_corr_scan_id: int | None = None,
    ):
        """Temporarily override scan IDs, restoring previous values on exit.

        The new ids are applied inside the ``try`` so that a setter failing
        part-way (e.g. an invalid ``conc_corr_scan_id``) still restores the
        ids that were already overwritten.
        """
        prev_start_scan_id = self._get_start_scan_id()
        prev_stop_scan_id = self._get_stop_scan_id()
        prev_conc_corr_scan_id = self._get_conc_corr_scan_id()
        try:
            self._set_start_scan_id(start_scan_id)
            self._set_stop_scan_id(stop_scan_id)
            self._set_conc_corr_scan_id(conc_corr_scan_id)
            yield
        finally:
            self._set_start_scan_id(prev_start_scan_id)
            self._set_stop_scan_id(prev_stop_scan_id)
            self._set_conc_corr_scan_id(prev_conc_corr_scan_id)

    def __call__(
        self,
        start_scan_id: int,
        stop_scan_id: int,
        conc_corr_scan_id: int | None = None,
        clear_plot: bool = True,
    ) -> None:
        """Process and plot the scan range ``start_scan_id..stop_scan_id``.

        Negative ids are relative to the most recent scan. Invalid ids raise
        ``ValueError``; workflow submission errors are logged as warnings.
        """
        with self._temporary_scan_ids(start_scan_id, stop_scan_id, conc_corr_scan_id):
            # Compare the resolved ids: comparing the raw arguments wrongly
            # rejects ranges such as (100, -1), i.e. "scan 100 up to latest".
            resolved_start = self._get_start_scan_id()
            resolved_stop = self._get_stop_scan_id()
            if resolved_start > resolved_stop:
                logger.warning(
                    f"Invalid scan indices; the start_scan_id ({resolved_start}) "
                    f"should not be larger than the stop_scan_id ({resolved_stop})."
                )
                return
            self._set_parameter("_clear_plot", clear_plot)
            try:
                self._trigger_workflow_on_new_scan(None)
            except Exception as e:
                logger.warning(e, exc_info=True)
class RixsHandler(WithFlintAccess):
    """Retrieve RIXS workflow results in the background and plot them in Flint."""

    def __init__(self) -> None:
        super().__init__()
        # Inputs of the workflow whose result is currently held (by name).
        self._workflow_inputs = None
        # Result of the last successful workflow, or None.
        self._result = None

    def clear(self):
        """Clear the Flint plot and reset its labels and title."""
        plot = self._get_plot()
        plot.submit("clear")
        plot.xlabel = "X"
        plot.ylabel = "Y"
        plot.title = ""

    def handle_workflow_result(self, *args, **kwargs) -> None:
        """Handle workflow result in a background task to recover data."""
        gevent.spawn(self._handle_workflow_result, *args, **kwargs)

    def _handle_workflow_result(
        self, task_id, timeout: int = 60, clear_plot: bool = True, workflow_inputs=None
    ) -> None:
        """Wait for the workflow result of ``task_id`` and plot it.

        :param task_id: id of the submitted workflow task.
        :param timeout: seconds to wait for the result.
        :param clear_plot: whether to clear the plot before drawing.
        :param workflow_inputs: workflow inputs keyed by name, kept alongside
            the result to build labels and the title.
        """
        logger.info("Recovering processed data for task %s", task_id)
        try:
            self._result = get_future(task_id).get(timeout=timeout)
            # Stored separately from the result; used for labels and title.
            self._workflow_inputs = workflow_inputs
            self.plot_data(clear_plot=clear_plot)
        except Exception as e:
            logger.error(str(e) or "The RIXS plotter workflow has failed.")
            # Keep the console message short but preserve the traceback.
            logger.debug("RIXS plotter failure details", exc_info=True)
            self._result = None

    def _get_plot(self, plot_id="Rixs Plot", plot_cls=RixsPlot):
        return super()._get_plot(plot_id, plot_cls)

    def _get_axes_labels(self) -> dict[str, str]:
        """Return ``{"x": ..., "y": ...}`` axis labels from the data mappings."""
        axes_labels = {"x": "X", "y": "Y"}
        if not self._workflow_inputs:
            return axes_labels
        data_mappings = self._workflow_inputs.get("data_mappings", {})
        labels = self._data_mappings_to_labels(data_mappings)
        xlabel = labels.get("x", "X")
        ylabel = labels.get("y", "Y")
        axes_labels["x"] = f"{xlabel} (keV)"
        y_axis = self._workflow_inputs.get("y_axis", "emission energy")
        if y_axis == "energy transfer":
            axes_labels["y"] = f"{xlabel} - {ylabel} (keV)"
        else:
            axes_labels["y"] = f"{ylabel} (keV)"
        return axes_labels

    def _get_filename(self):
        """Return the processed filename from the workflow inputs, or ``""``."""
        if self._workflow_inputs is None:
            return ""
        return self._workflow_inputs.get("filename", "")

    def _get_title(self) -> str:
        """Return the plot title: filename plus signal (and monitor) names."""
        if not self._workflow_inputs:
            return "Rixs Plot"
        filename = self._get_filename()
        labels = self._data_mappings_to_labels(
            self._workflow_inputs.get("data_mappings", {})
        )
        if labels["monitor"]:
            return f"{filename} \n {labels['signal']} / {labels['monitor']}"
        return f"{filename} \n {labels['signal']}"

    def _data_mappings_to_labels(self, data_mappings) -> dict[str, str]:
        """Convert data mapping paths to human-readable labels.

        For each key in data_mappings (e.g., "x", "y", "signal"), extract the
        last component of the path (after the final "/") as the label. If no
        mappings are provided, return default labels.
        """
        labels = {"x": "X", "y": "Y", "signal": "", "monitor": ""}
        if not data_mappings:
            return labels
        labels.update(
            {
                k: v.split("/")[-1]
                for k, v in data_mappings.items()
                if k in ["x", "y", "signal", "monitor"]
            }
        )
        return labels

    def _get_result(self) -> dict[str, Any] | None:
        return self._result

    @staticmethod
    def _axis_step(values) -> float:
        """Return the pixel size along an axis from its first two coordinates.

        Falls back to 1.0 when the axis has a single point (e.g. a map built
        from one scan) or a zero step, where the spacing is undefined.
        """
        if len(values) < 2:
            return 1.0
        step = values[1] - values[0]
        return step if step else 1.0

    def plot_data(self, clear_plot: bool = True):
        """Draw the stored result as an image in the Flint plot.

        Does nothing when there is no result or no associated workflow inputs.
        """
        result = self._get_result()
        if not result or self._workflow_inputs is None:
            return
        plot = self._get_plot()
        if clear_plot:
            plot.submit("clear")
        x, y, signal = result["x"], result["y"], result["signal"]
        plot.title = self._get_title()
        axes_labels = self._get_axes_labels()
        plot.xlabel = axes_labels["x"]
        plot.ylabel = axes_labels["y"]
        plot.set_data(
            signal,
            origin=(x[0], y[0]),
            scale=(self._axis_step(x), self._axis_step(y)),
            resetzoom=False,
        )
        plot.set_colormap(lut="viridis")
        plot.submit("setKeepDataAspectRatio", False)