"""EXAFS plotting in Flint."""
import logging
import os
from collections import OrderedDict
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from ewoksjob.client import submit
from ..flint import FlintClient
from ..flint.access import WithFlintAccess
from ..flint.colors import ColorCycler
from ..utils.directories import get_dataset_processed_dir
from .plots import ExafsPlot
from .types import ExafsPlotWorkflowParameters
from .types import ExafsSplitWorkflowParameters
from .types import ScanInfo
from .types import SubScanInfo
from .types import XasPlotData
from .types import XasSubscanData
logger = logging.getLogger(__name__)
class ExafsPlotter(WithFlintAccess):
    """Manage EXAFS plots in Flint based on Ewoks workflow results.

    Scans are cached in insertion order under the identifier returned by
    :meth:`create_scan_id`. Every cached scan holds a list of sub-scans and
    the matching list of workflow results. Only the last ``number_of_scans``
    sub-scans (counted over all cached scans) are kept enabled in the plot.
    """

    def __init__(self, number_of_scans: int = 0, queue: str = "celery") -> None:
        """
        :param number_of_scans: maximum number of sub-scans to be plotted.
            Negative values are treated as 0, like in the property setter.
        :param queue: queue name passed to ``submit`` for the plot workflow.
        """
        super().__init__()
        # Clamp here as well so the constructor and the setter agree.
        number_of_scans = max(number_of_scans, 0)
        self._number_of_scans = number_of_scans
        self._queue = queue

        # Fixed parameters: workflow plot name -> XasSubscanData field name
        self._plot_names = {
            "flatten_mu": "mu",
            "chi_weighted_k": "chi",
            "ft_mag": "ft",
            "noise_savgol": "noise",
        }

        # Runtime data
        self._scan_infos: Dict[str, ScanInfo] = OrderedDict()
        self._color_cycler = ColorCycler(max_colors=number_of_scans + 1)

    @property
    def number_of_scans(self) -> int:
        """Maximum number of scans to be plotted"""
        return self._number_of_scans

    @number_of_scans.setter
    def number_of_scans(self, value: int) -> None:
        # Negative limits make no sense: treat them as "plot nothing".
        number_of_scans = max(value, 0)
        self._number_of_scans = number_of_scans
        self._color_cycler.max_colors = number_of_scans + 1
        # Apply the new limit to the plots, then re-color and redraw.
        self.sync_plots()
        self.refresh()

    @property
    def scan_ids(self) -> List[str]:
        """Identifiers of all cached scans, in insertion order."""
        return list(self._scan_infos)

    def clear(self) -> None:
        """Remove all scan curves in all plots"""
        self._get_plot().clear()
        # The cache is kept, the sub-scans are only marked as not plotted.
        for scan_info in self._scan_infos.values():
            for subscan_info in scan_info.subscans:
                subscan_info.enabled = False

    def refresh(self) -> None:
        """Refresh all plots with the current processed data.

        Restarts the color cycle, assigns a new color to every cached
        sub-scan and marks it as updated, then updates the plots scan by scan.
        """
        self._color_cycler.reset()
        for scan_id, scan_info in self._scan_infos.items():
            for subscan_info in scan_info.subscans:
                subscan_info.updated = True
                subscan_info.color = self._color_cycler.next()
            self._update_scan_plot(scan_id)

    def create_scan_id(self, filename: str, scan_number: int) -> str:
        """Return the cache key of a scan: ``"<filename>|<scan_number>"``."""
        return f"{filename}|{scan_number}"

    def ensure_scan_infos(
        self, filename: str, scan_number: int
    ) -> Tuple[str, ScanInfo]:
        """Return the cached scan information, creating it when missing.

        :param filename: file name of the scan.
        :param scan_number: scan number within the file.
        :returns: the scan identifier and the cached ``ScanInfo``.
        """
        scan_id = self.create_scan_id(filename, scan_number)
        scan_info = self._scan_infos.get(scan_id)
        if scan_info:
            return scan_id, scan_info
        scan_info = ScanInfo(
            filename=filename, scan_number=scan_number, subscans=[], xas_results=[]
        )
        self._scan_infos[scan_id] = scan_info
        return scan_id, scan_info

    def _get_total_number_of_subscans(self) -> int:
        """Number of sub-scans summed over all cached scans."""
        return sum(
            len(scan_info.subscans) for scan_info in self._scan_infos.values()
        )

    def sync_plots(self) -> None:
        """Synchronize workflow results in cache with plots in Flint.

        Sub-scans are numbered over all cached scans in cache order. Only the
        last ``number_of_scans`` of them are enabled; the others are removed
        from the plot. The plot is only touched for sub-scans that were
        updated or whose enabled state changes.
        """
        n_subscans = self._get_total_number_of_subscans()
        min_enabled_subscan_idx = max(n_subscans - self.number_of_scans, 0)

        subscan_index = 0
        for scan_info in self._scan_infos.values():
            for subscan_info, xasplotdata in zip(
                scan_info.subscans, scan_info.xas_results
            ):
                enabled = subscan_index >= min_enabled_subscan_idx
                needs_refresh = subscan_info.updated or subscan_info.enabled != enabled
                if needs_refresh:
                    if enabled:
                        self._get_plot().update_scan(
                            subscan_info.legend,
                            xasplotdata,
                            subscan_info.color,
                        )
                    else:
                        self._get_plot().remove_scan(subscan_info.legend)
                    subscan_info.enabled = enabled
                    subscan_info.updated = False
                subscan_index += 1

    def remove_scan(self, legend: str) -> None:
        """Disable subscan and remove from Flint.

        :param legend: legend of the sub-scan to remove. When an enabled
            sub-scan was removed, the cache is purged afterwards.
        """
        removed = False
        for scan_info in self._scan_infos.values():
            for subscan_info in scan_info.subscans:
                if subscan_info.legend != legend:
                    continue
                if subscan_info.enabled:
                    subscan_info.enabled = False
                    self._get_plot().remove_scan(subscan_info.legend)
                    removed = True
                # Only the first match within a scan is handled.
                break
        if removed:
            self.purge_scan_infos()

    def purge_scan_infos(self, keep_scan_ids: Sequence[str] = tuple()) -> None:
        """Remove cache from scans that have no enabled sub-scans or no
        processing results.

        :param keep_scan_ids: identifiers of scans that must stay in the cache.
        """
        # Iterate over a copy: scans are deleted from the cache in the loop.
        for scan_id, scan_info in list(self._scan_infos.items()):
            if scan_id in keep_scan_ids:
                continue

            # Disable all subscans when the scan has no data
            if not scan_info.xas_results:
                for subscan_info in scan_info.subscans:
                    if subscan_info.enabled:
                        subscan_info.enabled = False
                        self._get_plot().remove_scan(subscan_info.legend)

            # Delete the scan when it has no enabled subscans
            has_enabled_subscan = any(
                subscan_info.enabled for subscan_info in scan_info.subscans
            )
            if not has_enabled_subscan:
                del self._scan_infos[scan_id]

    def _update_scan_plot(self, scan_id: str) -> None:
        """Update all scan curves in all plots"""
        scan_info = self._scan_infos.get(scan_id)
        if scan_info is not None:
            self.sync_results(scan_info)
        self.sync_plots()

    def _on_flint_restart(self, flint_client: FlintClient) -> None:
        """Collect pending results of all scans and synchronize the plots."""
        super()._on_flint_restart(flint_client)
        for scan_info in self._scan_infos.values():
            self.sync_results(scan_info)
        self.sync_plots()

    def sync_results(self, scan_info: ScanInfo) -> None:
        """Update workflow results in cache with processing results.

        Blocks when the scan has a pending future. The future is always
        dropped afterwards. A failing workflow is logged and leaves the
        cached results untouched.

        :param scan_info: cached scan whose plot workflow result is collected.
        """
        future = scan_info.plot_future
        if future is None:
            return
        try:
            workflow_results = future.result()
        except Exception as exc:
            # Workflow failed: report it instead of hiding it, but do not
            # raise so that plotting of the other scans keeps working.
            logger.warning(
                "EXAFS plot workflow failed for scan %s of %s: %s",
                scan_info.scan_number,
                scan_info.filename,
                exc,
            )
            logger.debug("EXAFS plot workflow failure details", exc_info=True)
            return
        finally:
            scan_info.plot_future = None

        xas_results = [
            XasSubscanData(
                **{
                    self._plot_names[plot_name]: XasPlotData(**plot_data)
                    for plot_name, plot_data in data.items()
                }
            )
            for data in workflow_results["plot_data"]
        ]
        if not xas_results:
            # Workflow succeeded but did not return any results
            return

        # The first returned result belongs to the first processed sub-scan.
        for idx, xasplotdata in enumerate(
            xas_results, scan_info.min_subscan_index_to_process
        ):
            if idx < len(scan_info.xas_results):
                # Known sub-scan: replace its data and flag it for redrawing.
                scan_info.xas_results[idx] = xasplotdata
                scan_info.subscans[idx].updated = True
            else:
                # New sub-scan: it gets enabled (or not) by `sync_plots`.
                basename = os.path.basename(os.path.dirname(scan_info.filename))
                legend = f"{basename}: {scan_info.scan_number}.{idx + 1}"
                subscan_info = SubScanInfo(
                    legend=legend, color=self._color_cycler.next(), enabled=False
                )
                scan_info.subscans.append(subscan_info)
                scan_info.xas_results.append(xasplotdata)

    def _submit_plot_workflow(
        self,
        scan_id: str,
        parameters: ExafsPlotWorkflowParameters,
        reprocess: bool = False,
    ) -> None:
        """Submit the data processing for a scan

        Does nothing when the scan is not cached. The resulting future is
        stored in ``scan_info.plot_future``.

        :param scan_id: identifier of a cached scan.
        :param parameters: plot workflow parameters.
        :param reprocess: stored in ``scan_info.reprocess_all``.
        """
        scan_info = self._scan_infos.get(scan_id, None)
        if scan_info is None:
            return
        scan_info.reprocess_all = reprocess

        measurement_url = f"{scan_info.scan_url}/measurement"
        input_information = {
            "channel_url": f"{measurement_url}/{parameters.energy_name}",
            "spectra_url": f"{measurement_url}/{parameters.mu_name}",
            "energy_unit": parameters.energy_unit,
        }
        if parameters.mon_name is not None:
            input_information["mu_ref_url"] = (
                f"{measurement_url}/{parameters.mon_name}"
            )
        if parameters.min_log is not None:
            input_information["min_log"] = parameters.min_log
        if scan_info.multi_xas_scan:
            input_information["is_concatenated"] = True
            input_information["trim_concatenated_n_points"] = parameters.trim_n_points
            input_information["skip_concatenated_n_spectra"] = (
                scan_info.min_subscan_index_to_process
            )
            input_information["concatenated_spectra_section_size"] = (
                scan_info.multi_xas_subscan_size
            )

        plot_inputs = [
            {
                "task_type": "ReadXasObject",
                "name": "input_information",
                "value": input_information,
            },
            {
                "task_type": "PlotSpectrumData",
                "name": "plot_names",
                "value": list(self._plot_names),
            },
        ]
        scan_info.plot_future = submit(
            args=(parameters.workflow,),
            kwargs={"inputs": plot_inputs},
            queue=self._queue,
        )

    def _submit_split_workflow(
        self,
        scan_id: str,
        parameters: ExafsSplitWorkflowParameters,
        reprocess: bool = False,
    ) -> None:
        """Submit the splitting of a multi-XAS scan into sub-scans

        Does nothing when the scan is not cached or is not a multi-XAS scan.
        The resulting future is stored in ``scan_info.split_future``.

        :param scan_id: identifier of a cached scan.
        :param parameters: split workflow parameters.
        :param reprocess: currently unused.
        """
        scan_info = self._scan_infos.get(scan_id, None)
        if scan_info is None or not scan_info.multi_xas_scan:
            return

        out_dirname = get_dataset_processed_dir(scan_info.filename)
        h5_basename = os.path.basename(scan_info.filename)
        h5_stem, _ = os.path.splitext(h5_basename)
        out_filename = os.path.join(out_dirname, h5_basename)
        convert_destination = os.path.join(
            out_dirname, "workflows", f"{h5_stem}_scan{scan_info.scan_number:04d}.json"
        )

        # Redirect output below this root to a temporary location. Only the
        # leading occurrence is a root, so replace at most once.
        ewoks_root = "/data/scisoft/ewoks"
        if out_filename.startswith(ewoks_root):
            out_filename = out_filename.replace(ewoks_root, "/tmp_14_days", 1)

        # All inputs go to the same task: input name -> value
        split_values = {
            "filename": scan_info.filename,
            "scan_number": scan_info.scan_number,
            "monotonic_channel": parameters.monotonic_channel,
            "subscan_size": scan_info.multi_xas_subscan_size,
            "trim_n_points": parameters.trim_n_points,
            "wait_finished": parameters.scan_complete,
            "out_filename": out_filename,
            "counter_group": "measurement",
        }
        split_inputs = [
            {"task_type": "SplitBlissScan", "name": name, "value": value}
            for name, value in split_values.items()
        ]

        # NOTE(review): unlike the plot workflow, no `queue` is passed here so
        # `submit` uses its default - confirm that this is intended.
        scan_info.split_future = submit(
            args=(parameters.workflow,),
            kwargs={"inputs": split_inputs, "convert_destination": convert_destination},
        )

    def execute_and_plot(
        self,
        scan_id: str,
        plot_parameters: ExafsPlotWorkflowParameters,
        split_parameters: ExafsSplitWorkflowParameters,
        reprocess: bool = False,
    ) -> None:
        """Submit the plot and split workflows of a scan and update the plots.

        Updating the plots blocks until the plot workflow has finished
        (see :meth:`sync_results`).

        :param scan_id: identifier of a cached scan.
        :param plot_parameters: parameters of the plot workflow.
        :param split_parameters: parameters of the split workflow.
        :param reprocess: forwarded to both workflow submissions.
        """
        self._submit_plot_workflow(scan_id, plot_parameters, reprocess=reprocess)
        self._submit_split_workflow(scan_id, split_parameters, reprocess=reprocess)
        self._update_scan_plot(scan_id)

    def _get_plot(self) -> ExafsPlot:
        """Launches Flint and creates the plot when either is missing"""
        return super()._get_plot("EXAFS", ExafsPlot)