from __future__ import annotations
import json
import logging
import os
from contextlib import contextmanager
from typing import Any
from ..import_utils import unavailable_module
try:
import gevent
except ImportError as ex:
gevent = unavailable_module(ex)
from blissoda.resources import resource_filename
from ..bliss_globals import current_session # type: ignore
from ..ewoks_utils import get_future
from ..ewoks_utils import submit
from ..flint.access import WithFlintAccess
from ..persistent.parameters import ParameterInfo
from ..processor import BaseProcessor
from ..processor import BlissScanType # type: ignore
from ..utils.directories import get_dataset_processed_dir
from .plots import RixsPlot
# Module-level logger shared by RixsPlotter and RixsHandler.
logger = logging.getLogger(__name__)
class RixsPlotter(
    BaseProcessor,
    parameters=[
        ParameterInfo("workflow", category="workflows"),
        ParameterInfo("data_mappings", category="workflows"),
        ParameterInfo("filename", category="workflows"),
        ParameterInfo("start_scan_id", category="workflows"),
        ParameterInfo("stop_scan_id", category="workflows"),
        ParameterInfo("conc_corr_scan_id", category="workflows"),
        ParameterInfo("y_axis", category="workflows"),
    ],
):
    """Online plotter for RIXS planes displayed in a Flint widget.

    Typically instantiated in the Bliss session setup script.

    Use ``enable()`` / ``disable()`` to start/stop automatic plotting on each
    new scan, or call the instance directly with a scan range::

        rixs_plotter(start_scan_id, stop_scan_id)

    Negative scan ids count back from the most recent scan of the session
    (``-1`` is the most recent scan).

    Call ``reset()`` before ``enable()`` to start a new map from scratch.

    The ``y_axis`` attribute can be set to either:

    - ``"emission energy"`` (default): plots emission energy on Y.
    - ``"energy transfer"``: plots (incident energy - emission energy) on Y.
    """

    DEFAULT_PARAMETERS = {
        "workflow": resource_filename("id26", "rixs_plotter.json"),
        "data_mappings": {
            "x": ".1/measurement/elapsed_time",
            "y": ".1/instrument/positioners/xes_en",
            "signal": ".1/measurement/sim_gaussian_1",
        },
        "filename": None,
        "start_scan_id": None,
        "stop_scan_id": None,
        "conc_corr_scan_id": None,
        "y_axis": "emission energy",
        "trigger_at": "END",
        "_enabled": False,
        "_clear_plot": True,
    }

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        defaults: dict[str, Any] | None = None,
        **deprecated_defaults: Any,
    ) -> None:
        """Create the plotter.

        Explicit ``defaults`` (merged with the deprecated keyword form) take
        precedence over ``DEFAULT_PARAMETERS``.
        """
        defaults = self._merge_defaults(deprecated_defaults, defaults)
        for parameter, value in self.DEFAULT_PARAMETERS.items():
            defaults.setdefault(parameter, value)
        super().__init__(config=config, defaults=defaults)
        self.handler = RixsHandler()

    def reset(self) -> None:
        """Restore the default parameters and clear the plot.

        Automatic plotting is disabled if it was enabled. The data mappings
        are kept unchanged.
        """
        for parameter, value in self.DEFAULT_PARAMETERS.items():
            if parameter == "_enabled" and self._enabled:
                self.disable()  # Only disable if currently enabled.
            elif parameter == "data_mappings":
                continue  # Do not reset data mappings.
            else:
                setattr(self, parameter, value)
        self.handler.clear()

    def _get_current_scan_id(self, scan: BlissScanType | None = None) -> int:
        """Return the number of ``scan``, or of the last scan of the session.

        :param scan: scan whose number is returned. When ``None``, the last
            scan of the current Bliss session is used.
        :raises ValueError: when no session/scan is available or the scan
            number is not an integer.
        """
        if scan is None:
            if current_session is None:
                raise ValueError("No current session available to get the scan from.")
            try:
                scan = current_session.scans[-1]
            except (AttributeError, IndexError):
                raise ValueError("No scans found in the current session.") from None
        scan_id = scan.scan_info["scan_nb"]
        if not isinstance(scan_id, int):
            raise ValueError("The id of the current scan is not a valid integer.")
        return scan_id

    def _resolve_scan_id(self, value: int, name: str) -> int:
        """Convert a possibly negative scan id into an absolute scan number.

        Negative values count back from the last scan of the session
        (``-1`` is the last scan).

        :param name: parameter name used in error messages.
        :raises ValueError: when the resolved scan number is zero or negative.
        """
        if value < 0:
            value = value + self._get_current_scan_id() + 1
        if value == 0:
            raise ValueError(f"The {name} cannot be zero.")
        if value < 0:
            # The offset reaches further back than the first scan.
            raise ValueError(
                f"The {name} resolves to a negative scan number ({value})."
            )
        return value

    def _get_start_scan_id(self) -> None | int:
        return self._get_parameter("start_scan_id")  # type: ignore

    def _set_start_scan_id(self, value: None | int) -> None:
        """Store the first scan of the range; negative values are resolved."""
        if value is not None:
            value = self._resolve_scan_id(value, "start_scan_id")
        self._set_parameter("start_scan_id", value)

    def _get_stop_scan_id(self) -> None | int:
        return self._get_parameter("stop_scan_id")  # type: ignore

    def _set_stop_scan_id(self, value: None | int) -> None:
        """Store the last scan of the range; negative values are resolved."""
        if value is not None:
            value = self._resolve_scan_id(value, "stop_scan_id")
        self._set_parameter("stop_scan_id", value)

    def _get_conc_corr_scan_id(self) -> None | int:
        return self._get_parameter("conc_corr_scan_id")  # type: ignore

    def _set_conc_corr_scan_id(self, value: None | int) -> None:
        """Store the concentration-correction scan id (must be positive or None)."""
        if value is None:
            self._set_parameter("conc_corr_scan_id", None)
            return
        if value < 0:
            raise ValueError(
                "The id of the concentration correction scan cannot be negative."
            )
        if value == 0:
            raise ValueError(
                "The id of the concentration correction scan cannot be zero."
            )
        self._set_parameter("conc_corr_scan_id", value)

    def _get_y_axis(self) -> str:
        return self._get_parameter("y_axis")  # type: ignore

    def _get_workflow(self) -> dict:
        """Load the workflow description from the JSON file ``self.workflow``."""
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(self.workflow, encoding="utf-8") as wf:
            return json.load(wf)

    def _get_workflow_inputs(self, scan) -> list:
        """Build the list of workflow inputs.

        When called for a ``scan``, the range is extended up to that scan and
        starts at it if no start scan was set yet.
        """
        if scan is not None:
            # Use the scan that triggered the call, consistent with the
            # filename which is also taken from it.
            scan_id = self._get_current_scan_id(scan)
            if self._get_start_scan_id() is None:
                self._set_start_scan_id(scan_id)
            self._set_stop_scan_id(scan_id)
        return [
            {"name": "data_mappings", "value": dict(self.data_mappings)},
            {"name": "filename", "value": self._get_filename(scan)},
            {"name": "start_scan_id", "value": self._get_start_scan_id()},
            {"name": "stop_scan_id", "value": self._get_stop_scan_id()},
            {"name": "conc_corr_scan_id", "value": self._get_conc_corr_scan_id()},
            {"name": "y_axis", "value": self._get_y_axis()},
        ]

    def _get_filename(self, scan=None) -> str:
        """Return the data file name.

        This is the configured ``filename`` if set. Otherwise it is the file
        of ``scan``, which is then remembered as ``filename``. Without a scan,
        it is the session's current scan-saving file name.
        """
        if self.filename is not None:
            return self.filename
        if scan is not None:
            filename = scan.scan_info.get("filename")
            if filename is None:
                raise ValueError("Scan has no filename.")
            self.filename = filename
            return filename
        if current_session is None:
            raise ValueError("No current session available to get the filename from.")
        return current_session.scan_saving.filename

    def _trigger_workflow_on_new_scan(self, scan: BlissScanType | None) -> None:
        """Submit the RIXS workflow and hand its result over to the plot handler.

        :param scan: the scan that triggered the call, or ``None`` to use the
            configured scan range.
        """
        workflow = self._get_workflow()
        workflow_inputs = self._get_workflow_inputs(scan)
        # Name -> value view of the inputs, used below and by the handler.
        inputs_by_name = {item["name"]: item["value"] for item in workflow_inputs}
        start_scan_id = inputs_by_name["start_scan_id"]
        stop_scan_id = inputs_by_name["stop_scan_id"]
        out_dirname = get_dataset_processed_dir(inputs_by_name["filename"])
        kwargs = {
            "inputs": workflow_inputs,
            "outputs": [{"all": False}],
        }
        # Only request a converted copy of the workflow when the dataset's
        # processed directory already exists.
        if os.path.exists(out_dirname):
            kwargs["convert_destination"] = os.path.join(
                out_dirname,
                "workflows",
                f"rixs_plotter_{start_scan_id}_to_{stop_scan_id}.json",
            )
        future = submit(args=(workflow,), kwargs=kwargs)
        self.handler.handle_workflow_result(
            future.task_id,
            clear_plot=self._get_parameter("_clear_plot"),
            workflow_inputs=inputs_by_name,
        )

    @contextmanager
    def _temporary_scan_ids(
        self,
        start_scan_id: int,
        stop_scan_id: int,
        conc_corr_scan_id: int | None = None,
    ):
        """Temporarily override scan IDs, restoring previous values on exit.

        The previous values are also restored when one of the new values is
        rejected.
        """
        prev_start_scan_id = self._get_start_scan_id()
        prev_stop_scan_id = self._get_stop_scan_id()
        prev_conc_corr_scan_id = self._get_conc_corr_scan_id()
        try:
            # Setters are inside ``try`` so a failing one still triggers the
            # restore of those already applied.
            self._set_start_scan_id(start_scan_id)
            self._set_stop_scan_id(stop_scan_id)
            self._set_conc_corr_scan_id(conc_corr_scan_id)
            yield
        finally:
            self._set_start_scan_id(prev_start_scan_id)
            self._set_stop_scan_id(prev_stop_scan_id)
            self._set_conc_corr_scan_id(prev_conc_corr_scan_id)

    def __call__(
        self,
        start_scan_id: int,
        stop_scan_id: int,
        conc_corr_scan_id: int | None = None,
        clear_plot: bool = True,
    ) -> None:
        """Plot the RIXS plane built from scans ``start_scan_id`` to ``stop_scan_id``.

        Negative ids count back from the last scan of the session. The stored
        scan range and plot-clearing setting are restored afterwards, so
        automatic plotting is not affected.

        Errors raised while submitting the workflow are logged, not raised.

        :raises ValueError: for invalid scan ids.
        """
        # Resolve negative ids first so the ordering check compares actual
        # scan numbers.
        start_scan_id = self._resolve_scan_id(start_scan_id, "start_scan_id")
        stop_scan_id = self._resolve_scan_id(stop_scan_id, "stop_scan_id")
        if start_scan_id > stop_scan_id:
            logger.warning(
                "Invalid scan indices; the start_scan_id (%s) "
                "should be smaller than the stop_scan_id (%s).",
                start_scan_id,
                stop_scan_id,
            )
            return
        prev_clear_plot = self._get_parameter("_clear_plot")
        self._set_parameter("_clear_plot", clear_plot)
        try:
            with self._temporary_scan_ids(
                start_scan_id, stop_scan_id, conc_corr_scan_id
            ):
                try:
                    self._trigger_workflow_on_new_scan(None)
                except Exception as e:
                    logger.warning(e, exc_info=True)
        finally:
            # The value is read synchronously when submitting, so restoring
            # it here does not affect this call.
            self._set_parameter("_clear_plot", prev_clear_plot)
class RixsHandler(WithFlintAccess):
    """Recover RIXS workflow results and display them in a Flint ``RixsPlot``."""

    def __init__(self) -> None:
        super().__init__()
        # Inputs (name -> value) of the workflow whose result is stored.
        self._workflow_inputs: dict[str, Any] | None = None
        # Result of the last handled workflow; None after a failure or clear().
        self._result: dict[str, Any] | None = None

    def clear(self) -> None:
        """Clear the plot, reset its labels and forget the stored result."""
        # Forget the previous map so plot_data() cannot redraw it.
        self._result = None
        self._workflow_inputs = None
        plot = self._get_plot()
        plot.submit("clear")
        plot.xlabel = "X"
        plot.ylabel = "Y"
        plot.title = ""

    def handle_workflow_result(self, *args, **kwargs) -> None:
        """Handle workflow result in a background task to recover data."""
        gevent.spawn(self._handle_workflow_result, *args, **kwargs)

    def _handle_workflow_result(
        self, task_id, timeout: int = 60, clear_plot: bool = True, workflow_inputs=None
    ) -> None:
        """Wait for workflow ``task_id`` and plot its result.

        :param timeout: forwarded to the future's ``get``.
        :param workflow_inputs: name -> value mapping of the workflow inputs,
            used for the plot title and axis labels.
        """
        logger.info("Recovering processed data for task %s", task_id)
        try:
            self._result = get_future(task_id).get(timeout=timeout)
            # Keep the inputs next to the result: title and axis labels are
            # derived from them.
            self._workflow_inputs = workflow_inputs
            self.plot_data(clear_plot=clear_plot)
        except Exception as e:
            if str(e):
                logger.error(e)
            else:
                logger.error("The RIXS plotter workflow has failed.")
            self._result = None

    def _get_plot(self, plot_id="Rixs Plot", plot_cls=RixsPlot):
        return super()._get_plot(plot_id, plot_cls)

    def _get_axes_labels(self) -> dict[str, str]:
        """Return the ``x``/``y`` axis labels derived from the workflow inputs."""
        if not self._workflow_inputs:
            return {"x": "X", "y": "Y"}
        labels = self._data_mappings_to_labels(
            self._workflow_inputs.get("data_mappings", {})
        )
        xlabel, ylabel = labels["x"], labels["y"]
        y_axis = self._workflow_inputs.get("y_axis", "emission energy")
        if y_axis == "energy transfer":
            y_text = f"{xlabel} - {ylabel} (keV)"
        else:
            y_text = f"{ylabel} (keV)"
        return {"x": f"{xlabel} (keV)", "y": y_text}

    def _get_filename(self) -> str:
        """Return the data file name of the stored workflow, or ``""``."""
        if self._workflow_inputs is None:
            return ""
        return self._workflow_inputs.get("filename", "")

    def _get_title(self) -> str:
        """Return the plot title: file name and signal (divided by monitor if any)."""
        if not self._workflow_inputs:
            return "Rixs Plot"
        filename = self._get_filename()
        labels = self._data_mappings_to_labels(
            self._workflow_inputs.get("data_mappings", {})
        )
        if labels["monitor"]:
            return f"{filename} \n {labels['signal']} / {labels['monitor']}"
        return f"{filename} \n {labels['signal']}"

    def _data_mappings_to_labels(self, data_mappings) -> dict[str, str]:
        """Convert data mapping paths to human-readable labels.

        For the keys ``x``, ``y``, ``signal`` and ``monitor``, the label is the
        last component of the path (after the final "/"). Missing or empty
        mappings keep the default labels.
        """
        labels = {"x": "X", "y": "Y", "signal": "", "monitor": ""}
        for key, path in (data_mappings or {}).items():
            # Skip unknown keys and empty/None paths instead of crashing.
            if key in labels and path:
                labels[key] = str(path).rsplit("/", 1)[-1]
        return labels

    def _get_result(self) -> dict[str, Any] | None:
        return self._result

    @staticmethod
    def _axis_step(values) -> float:
        """Return the spacing between the first two values of an axis.

        A single-point axis has no spacing; 1.0 is used as a placeholder.
        """
        if len(values) < 2:
            return 1.0
        return values[1] - values[0]

    def plot_data(self, clear_plot: bool = True) -> None:
        """Display the stored workflow result as an image in the RIXS plot.

        The image origin is the first ``x``/``y`` value, and its pixel size is
        the spacing of the first two values of each axis.
        """
        result = self._get_result()
        if not result or self._workflow_inputs is None:
            return
        x, y, signal = result["x"], result["y"], result["signal"]
        if len(x) == 0 or len(y) == 0:
            logger.warning("The RIXS plotter result has an empty axis; nothing to plot.")
            return
        plot = self._get_plot()
        if clear_plot:
            plot.submit("clear")
        plot.title = self._get_title()
        axes_labels = self._get_axes_labels()
        plot.xlabel = axes_labels["x"]
        plot.ylabel = axes_labels["y"]
        plot.set_data(
            signal,
            origin=(x[0], y[0]),
            scale=(self._axis_step(x), self._axis_step(y)),
            resetzoom=False,
        )
        plot.set_colormap(lut="viridis")
        plot.submit("setKeepDataAspectRatio", False)