from __future__ import annotations
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from ewoksutils.task_utils import task_inputs
from silx.io import h5py_utils
from ..ewoks_utils import get_future
from ..ewoks_utils import submit
from ..import_utils import unavailable_function
from ..persistent.parameters import ParameterInfo
from ..processor import BaseProcessorWithPlotting
from ..processor import BlissScanType
from ..utils.directories import get_filename
from ..utils.directories import master_output_filename
from ..utils.directories import workflow_destination
from ..xrpd.plotter import XrpdPlotter
from ..xrpd.utils import get_integrated_nxdata
# ``id09.status`` is an optional, beamline-specific dependency. When it cannot
# be imported, bind a stand-in built from the ImportError so that this module
# still imports (presumably the stand-in re-raises the error when called —
# confirm in ..import_utils.unavailable_function).
try:
    from id09.status import get_xray_energy
except ImportError as import_error:
    get_xray_energy = unavailable_function(import_error)

logger = logging.getLogger(__name__)
def _to_energy(energy: str | float | None) -> float:
if energy in ("automatic", "auto", None):
return "automatic"
return float(energy)
class TxsProcessor(
    BaseProcessorWithPlotting,
    parameters=[
        ParameterInfo("queue", category="workflows"),
        ParameterInfo("distance", category="Txs", doc="meter", validator=float),
        ParameterInfo("center", category="Txs", doc="pixel (hor, ver)"),
        ParameterInfo("binning", category="Txs", doc="(hor, ver)"),
        ParameterInfo("detector", category="Txs", validator=str),
        ParameterInfo("pixel", category="Txs", doc="meter (hor, ver)"),
        ParameterInfo("energy", category="Txs", doc="eV", validator=_to_energy),
        ParameterInfo("integrate1d_options", category="Txs"),
    ],
):
    """Submit TXS (``ewokstxs``) workflows for new Bliss scans and plot results.

    A scan is processed when its ``scan_info["channels"]`` contains
    ``"<detector>:image"``. The workflow is a single ``ewokstxs`` task. The
    saving variant is used when ``scan_info["save"]`` is truthy. The workflow
    is submitted to the worker queue named by the ``queue`` parameter, and the
    resulting future is handed to an :class:`XrpdPlotter`.
    """

    plotter_class = XrpdPlotter

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        defaults: Optional[Dict[str, Any]] = None,
        **deprecated_defaults: Dict[str, Any],
    ) -> None:
        """Initialise the processor, filling in missing default parameters.

        :param config: forwarded unchanged to the base class.
        :param defaults: default parameter values; entries missing here are
            filled in with the values below.
        :param deprecated_defaults: defaults passed as keyword arguments
            (deprecated calling style); merged into ``defaults`` by
            ``_merge_defaults``.
        """
        defaults = self._merge_defaults(deprecated_defaults, defaults)
        defaults.setdefault("queue", "online")
        defaults.setdefault("distance", 1)  # meter
        defaults.setdefault("center", (960, 960))  # pixel (hor, ver)
        defaults.setdefault("binning", 2)
        defaults.setdefault("detector", "rayonix")
        defaults.setdefault("pixel", None)  # meter (hor, ver)
        defaults.setdefault("energy", "automatic")  # eV, or "automatic"
        defaults.setdefault("integrate1d_options", dict())
        super().__init__(config=config, defaults=defaults)

    def _trigger_workflow_on_new_scan(self, scan: BlissScanType) -> None:
        # Private hook (presumably invoked by the base class for each new
        # scan — confirm in ..processor); delegates to the public method.
        self.trigger_workflow_on_new_scan(scan)

    def trigger_workflow_on_new_scan(self, scan: BlissScanType) -> None:
        """Submit the TXS workflow for ``scan`` and pass its future to the plotter.

        Scans that do not record ``<detector>:image`` are ignored. When the
        scan is saved, the workflow is also converted to ``workflow_destination``
        and the plotter is given the URL of the integrated result.
        """
        if not self.scan_requires_processing(scan):
            return None

        save = scan.scan_info.get("save")

        workflow = self.get_workflow(scan)
        # NOTE(review): ``get_inputs`` is not defined in this class body;
        # presumably provided by the base class or elsewhere — confirm.
        kwargs = {"inputs": self.get_inputs(scan), "outputs": [{"all": False}]}
        if save:
            kwargs["convert_destination"] = workflow_destination(scan)
        future = submit(args=(workflow,), kwargs=kwargs, queue=self.queue)
        # Re-obtain the future from its UUID through ``get_future`` before
        # handing it to the plotter.
        future = get_future(future.uuid)

        if not self._plotter:
            return None

        scan_nb = scan.scan_info.get("scan_nb")
        start_time = scan.scan_info.get("start_time")
        scan_name = self._get_scan_name(scan)
        # TODO: Do not hard-code the output url
        if save:
            output_url = (
                f"{master_output_filename(scan)}::/{scan_nb}.1/integrate/integrated"
            )
        else:
            output_url = None
        self._plotter.handle_workflow_result(
            future=future,
            timestamp=start_time,
            lima_name=self.detector,
            scan_name=scan_name,
            output_url=output_url,
        )

    def _get_scan_name(self, scan: BlissScanType) -> str:
        """Return the label under which the scan's result is plotted."""
        scan_nb = scan.scan_info.get("scan_nb")
        # NOTE(review): ``scan.name`` appears twice in this label; the leading
        # part may have been meant to be the data file name (``get_filename``
        # is imported but unused in the visible code) — confirm before changing.
        return f"{scan.name}: {scan_nb}.1 {scan.name}"

    def _get_txs_task_identifier(self, scan: BlissScanType) -> str:
        """Return the ewokstxs task class to use, depending on the scan's ``save`` flag."""
        if scan.scan_info.get("save"):
            return "ewokstxs.tasks.txs.TxsTask"
        return "ewokstxs.tasks.txs.TxsTaskWithoutSaving"

    def get_workflow(self, scan: BlissScanType) -> Dict[str, Any]:
        """Return the ewoks graph: a single ``txs_task`` node of class type."""
        return {
            "graph": {"id": "txs"},
            "nodes": [
                {
                    "id": "txs_task",
                    "task_type": "class",
                    "task_identifier": self._get_txs_task_identifier(scan),
                }
            ],
        }

    def scan_requires_processing(self, scan: BlissScanType) -> bool:
        """Return whether ``scan`` records the ``<detector>:image`` channel."""
        return f"{self.detector}:image" in scan.scan_info.get("channels", dict())

    @staticmethod
    def _channel_names(lima_name: str, nxdata) -> List[str]:
        """Prefix every member name of ``nxdata`` with ``"<lima_name>:"``."""
        return [f"{lima_name}:{key}" for key in nxdata.keys()]

    @h5py_utils.retry()
    def get_data_keys(self, scan: BlissScanType, lima_name: str) -> List[str]:
        """Return the processed-data channel names available for ``scan``.

        Each name is ``"<lima_name>:<key>"``, with one entry per member of the
        integrated NXdata group in the scan's master output file. Any of these
        names can be passed to :meth:`get_data`. File access is wrapped in
        ``h5py_utils.retry``.
        """
        filename = master_output_filename(scan)
        with h5py_utils.File(filename, mode="r") as root:
            nxdata = get_integrated_nxdata(root, scan)
            return self._channel_names(lima_name, nxdata)

    @h5py_utils.retry()
    def get_data(self, scan: BlissScanType, channel: str, idx=tuple()):
        """
        Get processed data for a scan. Tries to mirror the existing `scan.get_data` for raw data.
        A list of available channels for a given lima name can be retrieved using `get_data_keys`.

        :param scan: the scan whose master output file is read.
        :param channel: ``"<lima_name>:<field>"``. ``field`` is looked up in
            the integrated NXdata group; ``lima_name`` is only used in error
            messages.
        :param idx: index applied to the dataset. With the default empty tuple,
            h5py reads the whole dataset.
        :raises ValueError: if ``channel`` has no ``":"`` or an empty field.
        :raises KeyError: if ``field`` is not a member of the integrated NXdata.

        Ex:
        >>> scan = loopscan(5, 0.1, difflab6)
        >>> xrpd_processor.get_data_keys(scan, 'difflab6')
        ['difflab6:intensity', 'difflab6:points', 'difflab6:q']
        >>> xrpd_processor.get_data(scan, 'difflab6:intensity')
        array([...], dtype=float32)
        """
        # Split on the first ':' only. A plain split(":") followed by a
        # two-name unpacking fails with an opaque error on malformed input.
        lima_name, separator, field = channel.partition(":")
        if not separator or not field:
            raise ValueError(
                f"Invalid channel {channel!r}: expected '<lima_name>:<field>'."
            )
        filename = master_output_filename(scan)
        with h5py_utils.File(filename, mode="r") as root:
            nxdata = get_integrated_nxdata(root, scan)
            if field not in nxdata:
                raise KeyError(
                    f"{channel} is not a channel of this processing. "
                    f"Possible channels: {self._channel_names(lima_name, nxdata)}."
                )
            return nxdata[field][idx]