from __future__ import annotations
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from typing import Dict
from typing import Generator
from typing import List
from typing import Optional
from typing import Union
import fabio
import numpy
from ewoksutils.task_utils import task_inputs
from ..bliss_globals import setup_globals
from ..import_utils import unavailable_module
from ..persistent.parameters import ParameterInfo
from ..streamline.scanner import StreamlineScanner
from ..utils import validators
from ..utils.directories import get_dataset_processed_dir
from . import optimize_exposure
from .stop_scan_preset import StopIntegrateSum
from .utils import ensure_shutter_open
# ``streamline_changer`` is an optional dependency: when it is not installed,
# ``sc`` is bound to whatever ``unavailable_module`` builds from the import error
# so this module can still be imported.
try:
    import streamline_changer.sample_changer as sc
except ModuleNotFoundError as missing_dependency:
    sc = unavailable_module(missing_dependency)
class Id31StreamlineScanner(
    StreamlineScanner,
    parameters=[
        ParameterInfo(
            "optimize_pixel_value", category="exposure/attenuator", validator=float
        ),
        ParameterInfo(
            "optimize_nb_frames", category="exposure/attenuator", validator=int
        ),
        ParameterInfo(
            "optimize_min_exposure_time",
            category="exposure/attenuator",
            validator=float,
        ),
        ParameterInfo(
            "optimize_max_exposure_time",
            category="exposure/attenuator",
            validator=float,
        ),
        ParameterInfo("default_attenuator", category="exposure/attenuator"),
        ParameterInfo("attenuator_name", category="names"),
        ParameterInfo("newflat", category="Flat-field", validator=validators.is_file),
        ParameterInfo("oldflat", category="Flat-field", validator=validators.is_file),
        ParameterInfo("flat_enabled", category="Flat-field", validator=bool),
        ParameterInfo("optimize_exposure_per", category="robust vs. speed"),
        ParameterInfo("rockit_distance", category="sample changer"),
        ParameterInfo(
            "optimize_attenuator",
            category="exposure/attenuator",
            validator=bool,
        ),
        ParameterInfo("optimize_mask_file", category="exposure/attenuator"),
        ParameterInfo("baguette_type", category="sample changer"),
        ParameterInfo("auto_stop_acc_mode", category="auto stop", validator=bool),
        ParameterInfo("auto_stop_workflow_path", category="auto stop"),
        ParameterInfo("auto_stop_threshold", category="auto stop", validator=float),
        ParameterInfo("auto_stop_scan_npoints", category="auto stop", validator=int),
        ParameterInfo("detector_saturation", category="auto stop", validator=float),
        ParameterInfo("auto_stop_target_max", category="auto stop", validator=float),
        ParameterInfo("auto_stop_target_min", category="auto stop", validator=float),
        ParameterInfo("auto_stop_attenuation_mode", category="auto stop"),
        ParameterInfo(
            "auto_stop_ghost_threshold_per_frame",
            category="auto stop",
            validator=float,
        ),
        ParameterInfo(
            "auto_stop_spottiness_threshold", category="auto stop", validator=float
        ),
        ParameterInfo(
            "auto_stop_spotty_safe_atten", category="auto stop", validator=int
        ),
        ParameterInfo(
            "auto_stop_spotty_stability_frames", category="auto stop", validator=int
        ),
        ParameterInfo(
            "auto_stop_spotty_stability_tol", category="auto stop", validator=float
        ),
        ParameterInfo(
            "auto_stop_spotty_extra_atten", category="auto stop", validator=int
        ),
        ParameterInfo(
            "auto_stop_metric_timeout", category="auto stop", validator=float
        ),
        ParameterInfo(
            "auto_stop_sum_frame_max_counts",
            category="auto stop",
            validator=float,
        ),
    ],
):
    """ID31-specific :class:`StreamlineScanner`.

    On top of the generic streamline scanner this class adds:

    - exposure/attenuator optimization, either pre-computed per baguette or
      computed per sample (``optimize_exposure_per``);
    - flat-field correction inputs for the ``FlatFieldFromEnergy`` task;
    - rocking of the sample-changer translation during acquisition;
    - selectable baguette geometries (``baguette_type``);
    - an "auto-stop" accumulation mode (``auto_stop_acc_mode``) that runs a
      loopscan with a :class:`StopIntegrateSum` preset and integrates the
      summed frame.
    """

    # baguette types: step_size in motor user units (1 user unit = 13 mm)
    def _get_baguette_config(self, name: str) -> sc.BaguetteConfig:
        """Return the sample-changer baguette configuration called ``name``.

        :param name: baguette type name, ``"standard"`` or ``"wide"``
        :raises ValueError: when ``name`` is not a known baguette type
        """
        # Built on each call rather than at class level so that ``sc`` (which
        # may be the ``unavailable_module`` fallback) is only accessed when a
        # baguette type is actually requested.
        baguette_types = {
            "standard": sc.BAGUETTE_13MM,  # 16 samples, 13 mm pitch
            "wide": sc.BAGUETTE_25MM,  # 8 samples, 24.8 mm pitch
        }
        if name not in baguette_types:
            available = list(baguette_types.keys())
            raise ValueError(f"Unknown baguette type '{name}'. Available: {available}")
        return baguette_types[name]

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        defaults: Optional[Dict[str, Any]] = None,
        **deprecated_defaults: Any,
    ):
        """Create the scanner with ID31-specific parameter defaults.

        :param config: forwarded unchanged to :class:`StreamlineScanner`
        :param defaults: parameter defaults; every ID31 default set below is
            only applied to keys that are not already present
        :param deprecated_defaults: legacy keyword form of ``defaults``,
            combined with it through ``_merge_defaults``
        """
        defaults = self._merge_defaults(deprecated_defaults, defaults)
        defaults.setdefault("workflow", "streamline_without_calib_with_flat.json")
        defaults.setdefault("detector_name", "p3")
        defaults.setdefault("attenuator_name", "atten")
        defaults.setdefault("sample_changer_name", "streamline_sc")
        defaults.setdefault(
            "integration_options",
            {
                "method": "no_csr_ocl_gpu",
                "integrator_name": "sigma_clip_ng",
                "extra_options": {"max_iter": 3, "thres": 0},
                "error_model": "azimuthal",  # hybrid gives weird results
                "nbpt_rad": 4096,
                "unit": "q_nm^-1",
            },
        )
        defaults.setdefault("flat_enabled", True)
        defaults.setdefault("newflat", "/data/id31/inhouse/P3/flats.mat")
        defaults.setdefault("oldflat", "/data/id31/inhouse/P3/flats_old.mat")
        defaults.setdefault("optimize_pixel_value", 1e5)
        defaults.setdefault("optimize_nb_frames", 3)
        defaults.setdefault("optimize_min_exposure_time", 2)
        defaults.setdefault("optimize_max_exposure_time", 4)
        defaults.setdefault("optimize_exposure_per", "baguette")
        defaults.setdefault("rockit_distance", 0.07)
        defaults.setdefault("optimize_attenuator", True)
        defaults.setdefault("optimize_mask_file", None)
        defaults.setdefault("baguette_type", None)  # default: streamline_sc .yml config
        defaults.setdefault("auto_stop_acc_mode", False)
        defaults.setdefault("auto_stop_workflow_path", None)
        defaults.setdefault("auto_stop_threshold", 0.01)
        defaults.setdefault("auto_stop_scan_npoints", 10000)
        defaults.setdefault("detector_saturation", float(1 << 20))
        defaults.setdefault("auto_stop_target_max", None)
        defaults.setdefault("auto_stop_target_min", None)
        defaults.setdefault("auto_stop_attenuation_mode", "freeze")
        defaults.setdefault("auto_stop_ghost_threshold_per_frame", 100_000.0)
        defaults.setdefault("auto_stop_spottiness_threshold", 0.1)
        defaults.setdefault("auto_stop_spotty_safe_atten", 16)
        defaults.setdefault("auto_stop_spotty_stability_frames", 5)
        defaults.setdefault("auto_stop_spotty_stability_tol", 0.1)
        defaults.setdefault("auto_stop_spotty_extra_atten", 2)
        defaults.setdefault("auto_stop_metric_timeout", 2.0)
        defaults.setdefault("auto_stop_sum_frame_max_counts", 300_000.0)
        super().__init__(config=config, defaults=defaults)
        # One entry per sample (indexed by the sample changer's current sample
        # index); ``None`` means "no pre-defined condition for that sample".
        self._exposure_conditions: List[
            Union[optimize_exposure.ExposureCondition, None]
        ] = list()
        self._fixed_attenuator_position = None
        self._optimize_mask_array: Optional[numpy.ndarray] = None
        self._update_optimize_mask(self.optimize_mask_file)
        # upon Bliss restart push baguette type from Redis to streamline_scanner
        # object, which otherwise loads yml config (same as `baguette_type = None`)
        _btype = self.baguette_type
        if _btype is not None:
            self.baguette_type = _btype

    @property
    def baguette_type(self) -> Optional[str]:
        """Currently active baguette type name, returns None if unset."""
        return self._get_parameter("baguette_type")

    @baguette_type.setter
    def baguette_type(self, value: str) -> None:
        """Set the sample changer to use a different baguette type.

        Available types:

        - ``"standard"``: 16 samples, 13 mm pitch (step_size=1)
        - ``"wide"``: 8 samples, 24.8 mm pitch (step_size=24.8/13)

        :param value: baguette type name (see ``_get_baguette_config()``)
        :raises ValueError: when ``value`` is not a known baguette type; the
            sample changer and the stored parameter are then left unchanged
        """
        config = self._get_baguette_config(value)
        self.sample_changer.configure_baguette(config)
        self._set_parameter("baguette_type", value)
        print(
            f"Baguette type set to '{value}': "
            f"{config.number_of_samples} samples, "
            f"step_size={config.step_size:.4f} user units"
        )

    @property
    def optimize_exposure_per(self) -> Optional[str]:
        """Granularity of the exposure optimization: ``"sample"``,
        ``"baguette"`` or ``None`` (optimization disabled)."""
        return self._get_parameter("optimize_exposure_per")

    @optimize_exposure_per.setter
    def optimize_exposure_per(self, value: Optional[str]):
        if value not in (None, "sample", "baguette"):
            raise ValueError("Allowed values are 'sample', 'baguette' or None")
        self._set_parameter("optimize_exposure_per", value)

    @property
    def optimize_mask_file(self) -> Optional[str]:
        """File of the mask passed to the exposure optimization, or None."""
        return self._get_parameter("optimize_mask_file")

    @optimize_mask_file.setter
    def optimize_mask_file(self, filename: Optional[str]):
        self._update_optimize_mask(filename)

    def _update_optimize_mask(self, filename: Optional[str]):
        """Load the optimization mask from ``filename`` and store the file name.

        An empty/``None`` filename clears the mask. When the file cannot be
        read, the mask and the stored file name are both reset to ``None``
        (best effort: this is also called from ``__init__``).
        """
        if not filename:
            self._optimize_mask_array = None
            self._set_parameter("optimize_mask_file", None)
            return
        try:
            mask = fabio.open(filename).data
        except Exception as exc:
            # Deliberately best effort, but report why the mask was rejected.
            print(
                f"Error: cannot load optimize_mask_file {filename} ({exc}), "
                "resetting it!"
            )
            self._optimize_mask_array = None
            self._set_parameter("optimize_mask_file", None)
            return
        self._optimize_mask_array = mask
        self._set_parameter("optimize_mask_file", filename)

    @contextmanager
    def run_context(self):
        """Call ``setup_globals.shopen`` then enter the parent run context."""
        setup_globals.shopen(
            check_pilatus=False
        )  # check_pilatus = False when the detector was just started
        with super().run_context():
            yield

    def load(self):
        """Load the baguette, then pre-compute per-sample exposure conditions
        when optimizing per baguette (not in auto-stop mode)."""
        super().load()
        if self.optimize_exposure_per == "baguette" and not self.auto_stop_acc_mode:
            self.determine_exposure_conditions()

    def measure_sample(
        self, count_time: float = 1, *args, has_qrcode: bool = True, **kwargs
    ):
        """Measure the current sample.

        Dispatches to the auto-stop loopscan when ``auto_stop_acc_mode`` is set
        (``has_qrcode`` is then not used), otherwise to the default measurement
        with exposure optimization.
        """
        if self.auto_stop_acc_mode:
            return self._measure_sample_auto_stop(count_time, *args, **kwargs)
        else:
            return self._measure_sample_default(
                count_time, *args, has_qrcode=has_qrcode, **kwargs
            )

    def _measure_sample_default(
        self, count_time: float = 1, *args, has_qrcode: bool = True, **kwargs
    ):
        """Measure the current sample while rocking, with optimized exposure.

        The exposure time selected by ``_optimize_sample_exposure`` is clamped
        to ``[optimize_min_exposure_time, optimize_max_exposure_time]``. Note
        that the clamp is also applied when optimization is disabled, i.e. to
        the caller's ``count_time``.

        :returns: the parent ``measure_sample`` result, or ``None`` when the
            measurement raised a "Pilatus protection" ``RuntimeError``
        """
        with rockit(self.sample_changer.translation, self.rockit_distance):
            with self._optimize_sample_exposure(
                count_time, has_qrcode=has_qrcode
            ) as expo_time:
                expo_time_max = self.optimize_max_exposure_time
                expo_time_min = self.optimize_min_exposure_time
                expo_time = max(expo_time_min, min(expo_time, expo_time_max))
                if not self.dryrun:
                    ensure_shutter_open()
                try:
                    return super().measure_sample(expo_time, *args, **kwargs)
                except RuntimeError as e:
                    if "Pilatus protection" in str(e):
                        print(
                            f"Skip because of measurement error: {e}. Open shutter again"
                        )
                        return None
                    raise

    def determine_exposure_conditions(self):
        """Pre-define optimal conditions: ascan at fixed attenuator position.

        Applies the default attenuation (see ``_set_attenuation``) and stores
        the result of ``optimize_exposure.optimal_exposure_conditions`` in
        ``self._exposure_conditions``.
        """
        detector = getattr(setup_globals, self.detector_name)
        self._set_attenuation()
        self._exposure_conditions = optimize_exposure.optimal_exposure_conditions(
            *self.sample_changer.ascan_arguments(),
            detector,
            tframe=0.2,
            desired_counts=self.optimize_pixel_value,
            nframes_measure=1,
            nframes_default=self.optimize_nb_frames,
            reduce_desired_deviation=True,
            expose_with_integral_frames=False,
            mask=self._optimize_mask_array,
        )

    def _set_attenuation(self):
        """Apply ``default_attenuator``; when it is unset, adopt the current
        attenuator position (``bits``) as the new default instead."""
        if self.default_attenuator is None:
            attenuator = getattr(setup_globals, self.attenuator_name)
            self.default_attenuator = attenuator.bits
        else:
            setup_globals.att(self.default_attenuator)

    def determine_exposure_conditions_individually(self):
        """Pre-define optimal conditions: ct on each sample with adapted attenuator
        if the default attenuator position gives too many or too few counts.

        The attenuator position in effect before the call is restored
        afterwards, even on error.
        """
        detector = getattr(setup_globals, self.detector_name)
        attenuator = getattr(setup_globals, self.attenuator_name)
        att_value = attenuator.bits
        exposure_conditions = list()
        try:
            for _ in self.sample_changer.iterate_samples_without_qr():
                exposure_conditions.append(self._optimize_exposure_condition(detector))
        finally:
            setup_globals.att(att_value)
        self._exposure_conditions = exposure_conditions

    @contextmanager
    def _optimize_sample_exposure(
        self, count_time: float, has_qrcode: bool = True
    ) -> Generator[float, None, None]:
        """Select the optimal measurement conditions for the current sample and
        yield the corresponding exposure time.

        - optimization disabled: yield ``count_time`` unchanged;
        - ``"baguette"`` mode with a pre-defined condition for this sample:
          apply its attenuator position and yield its exposure time;
        - otherwise, without QR code: yield ``count_time`` unchanged;
        - otherwise: optimize this sample individually and restore the previous
          attenuator position when the ``with`` block exits.
        """
        if not self.optimize_exposure_per:
            # Optimization is disabled
            yield count_time
            return
        if self.optimize_exposure_per == "baguette":
            # Select pre-defined optimization if available. Kept in its own
            # variable: when there is no condition for this sample (None) we
            # fall through and must still be able to yield ``count_time``.
            predefined_time = self._set_exposure_condition()
            if predefined_time is not None:
                yield predefined_time
                return
        if not has_qrcode:
            # No QR-code probably means no sample so do not waste time optimizing
            yield count_time
            return
        # Optimize condition for this sample individually
        detector = getattr(setup_globals, self.detector_name)
        attenuator = getattr(setup_globals, self.attenuator_name)
        att_value = attenuator.bits
        try:
            condition = self._optimize_exposure_condition(detector)
            yield condition.expo_time
        finally:
            setup_globals.att(att_value)

    def _set_exposure_condition(self) -> Optional[float]:
        """Apply the pre-defined exposure condition of the current sample.

        Computes the conditions first when none are available yet.

        :returns: the condition's exposure time, or ``None`` when no condition
            is defined for the current sample (attenuator left untouched)
        """
        if not self._exposure_conditions:
            self.determine_exposure_conditions()
        sample_index = self.sample_changer.current_sample_index
        condition = self._exposure_conditions[sample_index]
        print(f"Pre-defined optimal exposure conditions: {condition}")
        if condition is None:
            return None
        setup_globals.att(condition.att_position)
        return condition.expo_time

    def _optimize_exposure_condition(
        self, detector
    ) -> optimize_exposure.ExposureCondition:
        """Optimize the exposure of the sample currently in the beam."""
        return optimize_exposure.optimize_exposure_condition(
            detector,
            tframe=0.2,
            default_att_position=self.default_attenuator,
            desired_counts=self.optimize_pixel_value,
            dynamic_range=1 << 20,
            # NOTE(review): an older comment here read "take 100"; confirm
            # whether 0 is really the intended minimum.
            min_counts_per_frame=0,
            nframes_measure=1,
            nframes_default=self.optimize_nb_frames,
            reduce_desired_deviation=True,
            expose_with_integral_frames=False,
            optimize_attenuator=self.optimize_attenuator,
            mask=self._optimize_mask_array,
        )

    def _measure_sample_auto_stop(self, count_time, *args, **kwargs):
        """Measure a sample using a loopscan + StopIntegrateSum preset.

        The loopscan runs until the supplied threshold is crossed or npoints is
        reached. Sums all accumulated frames and writes a ``*_sum.h5`` file.

        On Pilatus protection, the attenuator is stepped up and the loopscan
        is retried in a **fresh dataset** so the saturated frames never share
        an HDF5 file with the converged attempt (scan number inside any
        auto-stop dataset is always ``1.1``); the sample is skipped only when
        the attenuator is already at its maximum position.

        :returns: the completed scan, or ``None`` in dry-run mode or when the
            sample was skipped
        :raises RuntimeError: when ``auto_stop_workflow_path`` or
            ``pyfai_config`` is unset, or on any scan error other than
            "Pilatus protection"
        """
        if not self.auto_stop_workflow_path:
            raise RuntimeError("auto_stop_workflow_path is not set.")
        if not self.pyfai_config:
            raise RuntimeError("pyfai_config is not set.")
        if self.dryrun:
            print("Dry-run: skip auto-stop measurement")
            return None
        # set `att` to default if present, otherwise to current atten.bits
        self._set_attenuation()
        attenuator = getattr(setup_globals, self.attenuator_name)
        att_position_max = 31  # matches optimize_exposure.optimize_exposure_condition
        while True:
            preset = StopIntegrateSum(
                workflow_threshold=self.auto_stop_threshold,
                workflow_path=self.auto_stop_workflow_path,
                detector_name=self.detector_name,
                detector_saturation=self.detector_saturation,
                pyfai_config_path=self.pyfai_config,
                attenuator_name=self.attenuator_name,
                frame_target_max=self.auto_stop_target_max,
                frame_target_min=self.auto_stop_target_min,
                ewoksjob_queue=self.queue,
                attenuation_mode=self.auto_stop_attenuation_mode,
                ghost_threshold_per_frame=self.auto_stop_ghost_threshold_per_frame,
                spottiness_threshold=self.auto_stop_spottiness_threshold,
                spotty_safe_atten=self.auto_stop_spotty_safe_atten,
                spotty_stability_frames=self.auto_stop_spotty_stability_frames,
                spotty_stability_tol=self.auto_stop_spotty_stability_tol,
                spotty_extra_atten=self.auto_stop_spotty_extra_atten,
                metric_timeout=self.auto_stop_metric_timeout,
                sum_frame_max_threshold=self.auto_stop_sum_frame_max_counts,
            )
            with rockit(self.sample_changer.translation, self.rockit_distance):
                ensure_shutter_open()
                scan = setup_globals.loopscan(
                    self.auto_stop_scan_npoints, count_time, *args, run=False, **kwargs
                )
                scan.acq_chain.add_preset(preset)
                try:
                    scan.run()
                    return scan
                except RuntimeError as e:
                    if "Pilatus protection" not in str(e):
                        raise
                    current_att = attenuator.bits
                    if current_att < att_position_max:
                        new_att = min(current_att + 2, att_position_max)
                        setup_globals.att(new_att)
                        print(
                            f"Pilatus protection fired at att={current_att}. "
                            f"Increasing to {new_att} and retrying."
                        )
                        # Retry in a fresh dataset (see docstring).
                        setup_globals.newdataset()
                        continue
                    print(
                        f"Skip because of measurement error: {e}. "
                        f"Attenuator already at max ({att_position_max})."
                    )
                    return None

    def init_workflow(self, with_autocalibration: bool = False) -> None:
        """Select the data-processing workflow, with or without autocalibration."""
        if with_autocalibration:
            self.workflow = "streamline_with_calib_with_flat.json"
        else:
            self.workflow = "streamline_without_calib_with_flat.json"
        print(f"Active data processing workflow: {self.workflow}")

    def _job_arguments(self, scan_info, processed_metadata: dict):
        """Extend the parent's workflow job arguments with ID31 task inputs.

        Adds the flat-field inputs, the extra ``SaveNexusPattern1D`` output for
        the non-sigma-clipped integration, and the ``ewoksid31.workflows``
        load option. In auto-stop mode, every ``image`` input is redirected to
        the ``*_sum.h5`` file of the scan.
        """
        args, kwargs = super()._job_arguments(scan_info, processed_metadata)
        detector_name = self.detector_name
        kwargs["inputs"] += task_inputs(
            task_identifier="FlatFieldFromEnergy",
            inputs={
                "newflat": self.newflat,
                "oldflat": self.oldflat,
                "energy": getattr(setup_globals, self.energy_name).position,
                "enabled": self.flat_enabled and detector_name == "p3",
            },
        )
        # Rely on StreamlineScanner to set url and bliss_scan_url as for other SaveNexusPattern1D tasks
        kwargs["inputs"] += task_inputs(
            task_identifier="SaveNexusPattern1D",
            label="save_q_no_sigmaclip_hdf5",
            inputs={
                "nxprocess_name": f"{detector_name}_integrate_q_no_sigmaclip",
                "nxmeasurement_name": f"{detector_name}_integrated_q_no_sigmaclip",
                "metadata": {
                    f"{detector_name}_integrate_q_no_sigmaclip": {
                        "configuration": {"workflow": self.workflow}
                    }
                },
            },
        )
        # Use workflows from ewoksid31 module
        kwargs["load_options"] = {"root_module": "ewoksid31.workflows"}
        # if using auto stop mode, replace the image URL to be input into the
        # Integrate1D task with the *_sum.h5 file produced by WriteSumFrame,
        # which contains the sum of all frames collected by the loopscan
        if self.auto_stop_acc_mode:
            filename, scan_nb = scan_info.filename, scan_info.scan_nb
            stem = Path(filename).stem
            sum_filename = str(
                Path(get_dataset_processed_dir(filename)) / f"{stem}_sum.h5"
            )
            sum_path = f"/{scan_nb}.1/measurement/{detector_name}"
            for task_input in kwargs["inputs"]:
                if task_input["name"] == "image":
                    task_input["value"] = (
                        f"silx://{sum_filename}?path={sum_path}&slice=0"
                    )
        return args, kwargs
@contextmanager
def rockit(motor, distance):
    """Run the ``with`` block inside ``setup_globals.rockit(motor, distance)``.

    A falsy ``distance`` (``0``, ``0.0``, ``None``) disables rocking:
    ``setup_globals.rockit`` is not called and the block runs as-is.
    "ROCKING OFF" is printed when leaving an enabled block, even on error.
    """
    if not distance:
        print("ROCKING DISABLED")
        yield
        return
    try:
        with setup_globals.rockit(motor, distance):
            print("ROCKING ON")
            yield
    finally:
        print("ROCKING OFF")