from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
from typing import Optional
import numpy as np
from pint.errors import PintError
from ...bliss_globals import current_session
from ...ewoks_utils import submit
from ...persistent.parameters import ParameterInfo
from ...processor import BaseProcessor
from ...processor import BlissScanType
from ...resources import resource_filename
from ..utils import calculate_relative_CoR_estimate
from . import backandforth
from . import helical
from . import holotomo
from . import multiturn
from . import standard
from . import zseries
# Unit registry used for metadata unit conversions: ``bliss.physics.units``
# when BLISS is importable, otherwise a standalone pint ``UnitRegistry``.
try:
    from bliss.physics import units
except ImportError:
    from pint import UnitRegistry

    units = UnitRegistry()
_TASK_IDENTIFIER = "ewokstomo.tasks.tomobasictonxtomo.TomoBasicToNXtomo"
_SUBSCAN_IMAGE_KEYS = {"tomo:dark": 2, "tomo:flat": 1, "tomo:return_ref": -1}
_PROJECTION_SUBSCAN_KINDS = frozenset(
{
"tomo",
"tomo:step",
"tomo:continuous",
"tomo:sweep",
"tomo:interlaced",
"tomo:multiturns",
"tomo:fulltomo",
}
)
_logger = logging.getLogger(__name__)
def _to_json_serializable(value: Any) -> Any:
if isinstance(value, np.ndarray):
return value.tolist()
if isinstance(value, np.generic):
return value.item()
if isinstance(value, dict):
return {k: _to_json_serializable(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_to_json_serializable(v) for v in value]
return value
class CreateNxTomoProcessorBase(
    BaseProcessor,
    parameters=[
        ParameterInfo(
            "workflow",
            category="workflows",
            doc="Workflow file used to create NXtomo from tomo sequence metadata",
        ),
        ParameterInfo("queue", category="workflows"),
        ParameterInfo(
            "output_path",
            category="files",
            doc="Optional output .nx path (or output directory)",
        ),
        ParameterInfo(
            "_bliss_hdf5_path",
            category="files",
            doc="HDF5 dataset path (filled automatically)",
        ),
        ParameterInfo(
            "offset_mm",
            category="estimate_center_of_rotation",
            doc="Offset (mm) subtracted from the translation_y motor position",
        ),
    ],
):
    """Submit an ewoks workflow that turns a BLISS tomo sequence into NXtomo file(s).

    The sequence flavour (holotomo, z-series, helical, back-and-forth,
    multi-turn, or standard as fallback) is detected from the scan metadata.
    The matching sibling module may override the naming/path hooks
    (``dataset_stem``, ``output_stem``, ``dataset_processed_dir``,
    ``build_output_path``, ``workflows_dir``, ``synthetic_rotation``). It must
    provide ``expanded_subscan_plan``, ``synthetic_translation`` and
    ``segment_specs``.

    One workflow is submitted per segment returned by ``segment_specs``.
    Each segment is identified by its ``label``.

    NOTE(review): the ``output_path`` parameter is declared (default ``None``)
    but is not read anywhere in this class. Output paths always come from
    ``_build_output_path``. Confirm whether a sequence module or subclass uses it.
    """

    def _unit_registry(self):
        """Return ``units.ur`` when it exists, otherwise ``units`` itself."""
        return getattr(units, "ur", units)

    def _convert_quantity(
        self, value: Any, unit: str, target_unit: str, field_name: str
    ) -> float:
        """Convert ``value`` expressed in ``unit`` into ``target_unit``.

        :param value: numeric value (anything accepted by ``float()``).
        :param unit: unit string attached to ``value``.
        :param target_unit: unit to convert into.
        :param field_name: metadata field name, used in the error message.
        :returns: the converted magnitude.
        :raises ValueError: when pint rejects the unit/conversion or ``value``
            is not a valid number.
        """
        try:
            registry = self._unit_registry()
            quantity = registry.Quantity(float(value), unit)
            return float(quantity.to(target_unit).magnitude)
        except (AttributeError, PintError, ValueError) as exc:
            raise ValueError(f"Unsupported {field_name} unit: {unit!r}") from exc

    def __init__(
        self,
        config: Optional[dict[str, Any]] = None,
        defaults: Optional[dict[str, Any]] = None,
    ) -> None:
        """Initialise the processor with this class's default parameter values.

        ``defaults`` is copied first, so the keys added here never leak back
        into the caller's mapping.
        """
        # Copy so that setdefault() below does not mutate a dict owned by the
        # caller (e.g. one shared between several processor instances).
        defaults = dict(defaults) if defaults else {}
        defaults.setdefault("trigger_at", "PREPARED")
        defaults.setdefault("workflow", "create_nxtomo.json")
        defaults.setdefault("queue", None)
        defaults.setdefault("output_path", None)
        defaults.setdefault("offset_mm", 0.0)
        super().__init__(config=config, defaults=defaults)

    def _sequence_name(self, entry: Any) -> str:
        """Return the sequence name of ``entry``.

        Tries, in order: ``technique/scan/sequence``, ``technique/scan_category``,
        then ``title``. Returns an empty string when none of these is present.
        """
        technique = entry.get("technique", {})
        technique_scan = technique.get("scan", {})
        sequence = technique_scan.get("sequence")
        if sequence is not None:
            return str(sequence)
        scan_category = technique.get("scan_category")
        if scan_category is not None:
            return str(scan_category)
        return str(entry.get("title", ""))

    def _is_standard_sequence(self, entry: Any) -> bool:
        """Return whether the ``standard`` module recognises ``entry``."""
        return standard.matches(self, entry)

    def _is_zseries_sequence(self, entry: Any) -> bool:
        """Return whether the ``zseries`` module recognises ``entry``."""
        return zseries.matches(self, entry)

    def _is_holotomo_subsequence(self, entry: Any) -> bool:
        """Return whether the ``holotomo`` module recognises ``entry``."""
        return holotomo.matches(self, entry)

    def _is_helical_sequence(self, entry: Any) -> bool:
        """Return whether the ``helical`` module recognises ``entry``."""
        return helical.matches(self, entry)

    def _is_supported_sequence(self, entry: Any) -> bool:
        """Return whether any known sequence module recognises ``entry``."""
        return (
            self._is_holotomo_subsequence(entry)
            or self._is_zseries_sequence(entry)
            or self._is_standard_sequence(entry)
            or self._is_helical_sequence(entry)
            or self._is_backandforth_sequence(entry)
            or self._is_multiturn_sequence(entry)
        )

    def _is_backandforth_sequence(self, entry: Any) -> bool:
        """Return whether the ``backandforth`` module recognises ``entry``."""
        return backandforth.matches(self, entry)

    def _is_multiturn_sequence(self, entry: Any) -> bool:
        """Return whether the ``multiturn`` module recognises ``entry``."""
        return multiturn.matches(self, entry)

    def _sequence_module(self, entry: Any):
        """Return the sequence module that handles ``entry``.

        Modules are tried in a fixed order; ``standard`` is the fallback.
        """
        if self._is_holotomo_subsequence(entry):
            return holotomo
        if self._is_zseries_sequence(entry):
            return zseries
        if self._is_helical_sequence(entry):
            return helical
        if self._is_backandforth_sequence(entry):
            return backandforth
        if self._is_multiturn_sequence(entry):
            return multiturn
        return standard

    def _module_hook(self, entry: Optional[Any], hook_name: str) -> Optional[Any]:
        """Return the sequence module's ``hook_name`` attribute, or ``None``.

        Returns ``None`` when ``entry`` is ``None`` or the module defines no
        such attribute.
        """
        if entry is None:
            return None
        module = self._sequence_module(entry)
        return getattr(module, hook_name, None)

    def _default_dataset_stem(
        self, bliss_path: str, label: Optional[str] = None
    ) -> str:
        """Return the file stem of ``bliss_path``, suffixed with ``_<label>`` if given."""
        stem = Path(bliss_path).with_suffix("").name
        if label:
            return f"{stem}_{label}"
        return stem

    def _default_output_stem(self, bliss_path: str, label: Optional[str] = None) -> str:
        """Return ``<file stem>_seq_start``, suffixed with ``_<label>`` if given."""
        stem = f"{Path(bliss_path).with_suffix('').name}_seq_start"
        if label:
            return f"{stem}_{label}"
        return stem

    def _default_dataset_processed_dir(
        self, bliss_path: str, label: Optional[str] = None
    ) -> Path:
        """Return the PROCESSED_DATA counterpart of the directory holding ``bliss_path``.

        Every ``RAW_DATA`` occurrence in the path is replaced by
        ``PROCESSED_DATA``. When ``label`` is given, the directory name gets
        the suffix ``_<label>``.
        """
        processed_path = bliss_path.replace("RAW_DATA", "PROCESSED_DATA")
        nx_path = Path(processed_path).with_suffix("")
        dataset_dir = nx_path.parent
        if label:
            dataset_dir = dataset_dir.with_name(f"{dataset_dir.name}_{label}")
        return dataset_dir

    def _dataset_stem(
        self,
        bliss_path: str,
        label: Optional[str] = None,
        entry: Optional[Any] = None,
    ) -> str:
        """Return the dataset stem: the module ``dataset_stem`` hook result, else the default."""
        hook = self._module_hook(entry, "dataset_stem")
        if hook is not None:
            return str(hook(self, bliss_path, label))
        return self._default_dataset_stem(bliss_path, label)

    def _output_stem(
        self,
        bliss_path: str,
        label: Optional[str] = None,
        entry: Optional[Any] = None,
    ) -> str:
        """Return the output stem: the module ``output_stem`` hook result, else the default."""
        hook = self._module_hook(entry, "output_stem")
        if hook is not None:
            return str(hook(self, bliss_path, label))
        return self._default_output_stem(bliss_path, label)

    def _dataset_processed_dir(
        self,
        bliss_path: str,
        label: Optional[str] = None,
        entry: Optional[Any] = None,
    ) -> Path:
        """Return the processed dataset dir: module hook result, else the default."""
        hook = self._module_hook(entry, "dataset_processed_dir")
        if hook is not None:
            return Path(hook(self, bliss_path, label))
        return self._default_dataset_processed_dir(bliss_path, label)

    def _build_output_path(
        self,
        bliss_path: str,
        label: Optional[str] = None,
        entry: Optional[Any] = None,
    ) -> str:
        """Return the ``.nx`` output path for a segment.

        Uses the module ``build_output_path`` hook when defined. Otherwise
        returns ``<processed dir>/projections/<output stem>.nx``.
        """
        hook = self._module_hook(entry, "build_output_path")
        if hook is not None:
            return str(hook(self, bliss_path, label))
        dataset_dir = self._dataset_processed_dir(bliss_path, label, entry)
        projections_dir = dataset_dir / "projections"
        return str(
            projections_dir / f"{self._output_stem(bliss_path, label, entry)}.nx"
        )

    def _get_workflows_dir(
        self,
        dataset_filename: str,
        label: Optional[str] = None,
        entry: Optional[Any] = None,
    ) -> Path:
        """Return the directory where converted workflow files are written.

        Uses the module ``workflows_dir`` hook when defined. Otherwise returns
        ``<processed dir>/workflows/gallery``.
        """
        hook = self._module_hook(entry, "workflows_dir")
        if hook is not None:
            return Path(hook(self, dataset_filename, label))
        return (
            self._dataset_processed_dir(dataset_filename, label, entry)
            / "workflows"
            / "gallery"
        )

    def _get_workflow_upload_parameters(
        self,
        scan: BlissScanType,
        label: Optional[str] = None,
        entry: Optional[Any] = None,
    ) -> Optional[dict[str, Any]]:
        """Return the ``upload_parameters`` for the workflow submission.

        Returns ``None`` when the scan's ``save`` flag is falsy. Otherwise the
        values are taken from the current session's ``scan_saving``.
        """
        if not scan.scan_info.get("save"):
            return None
        scan_saving = current_session.scan_saving
        filename = scan.scan_info.get("filename") or scan_saving.filename
        metadata = {"Sample_name": scan_saving.dataset["Sample_name"]}
        workflows_dir = self._get_workflows_dir(filename, label, entry)
        raw_directory = str(Path(filename).parent)
        return {
            "beamline": scan_saving.beamline,
            "proposal": scan_saving.proposal_name,
            "dataset": "workflows",
            "path": str(workflows_dir),
            "raw": [raw_directory],
            "metadata": metadata,
        }

    def workflow_destination(
        self, label: Optional[str] = None, entry: Optional[Any] = None
    ) -> str:
        """Return ``<workflows dir>/<output stem>_nx.json`` for the given segment.

        Relies on ``self._bliss_hdf5_path``, which ``_get_input_sets`` sets.
        """
        workflows_dir = self._get_workflows_dir(self._bliss_hdf5_path, label, entry)
        filename = f"{self._output_stem(self._bliss_hdf5_path, label, entry)}_nx.json"
        return str(workflows_dir / filename)

    def _is_tomo_sequence_scan(self, scan: BlissScanType) -> bool:
        """Return whether ``scan`` belongs to a supported tomo sequence."""
        return self._is_supported_sequence(self._entry(scan))

    def _sequence_root_scan_number(self, scan_info: dict[str, Any]) -> int:
        """Return the scan number of the sequence (root) scan.

        With ``index_in_sequence`` present, the result is
        ``scan_nb - index_in_sequence - 1``. This matches ``_subscan_records``,
        which places the i-th subscan (1-based) at ``root + i``; that implies
        ``index_in_sequence`` is 0-based. Without it, ``scan_nb`` itself is
        returned.

        :raises ValueError: if ``scan_nb`` is missing.
        """
        scan_nb = scan_info.get("scan_nb")
        if scan_nb is None:
            raise ValueError("scan_nb is required")
        index_in_sequence = scan_info.get("index_in_sequence")
        if index_in_sequence is None:
            return int(scan_nb)
        return int(scan_nb) - int(index_in_sequence) - 1

    def _entry(self, scan: BlissScanType) -> dict[str, Any]:
        """Return a shallow copy of ``scan.scan_info`` with ``scan_nb`` set to the root scan number."""
        info = dict(scan.scan_info)
        info["scan_nb"] = self._sequence_root_scan_number(info)
        return info

    def _detector_name(self, entry: Any) -> str:
        """Return the first configured tomo detector name, without an ``_optic`` suffix."""
        detector = str(
            np.asarray(entry["technique"]["tomoconfig"]["detector"]).reshape(-1)[0]
        )
        if detector.endswith("_optic"):
            return detector[: -len("_optic")]
        return detector

    def _detector_axes_metadata_name(self, entry: Any) -> str:
        """Return the instrument entry name that holds the detector ``data_axes``.

        For each active tomo config, the config's ``tomo_detector`` names are
        searched in order. The first instrument entry that is a dict containing
        ``data_axes`` wins.

        :raises KeyError: if no such entry exists.
        """
        instrument = entry["instrument"]
        active_config_names = np.asarray(
            entry["technique"].get("active_tomo_config", [])
        ).reshape(-1)
        for config_name in active_config_names:
            config_entry = instrument.get(str(config_name))
            if not isinstance(config_entry, dict):
                continue
            tomo_detector_names = np.asarray(
                config_entry.get("tomo_detector", [])
            ).reshape(-1)
            for tomo_detector_name in tomo_detector_names:
                metadata_name = str(tomo_detector_name)
                metadata_entry = instrument.get(metadata_name)
                if isinstance(metadata_entry, dict) and "data_axes" in metadata_entry:
                    return metadata_name
        raise KeyError(
            "Could not resolve instrument/<tomo_detector>/data_axes from active_tomo_config"
        )

    def _configured_name(self, entry: Any, name: str) -> str:
        """Resolve the tomoconfig alias ``name`` to a positioner name.

        Returns the first alias found in ``positioners/positioners_start``,
        falling back to the first alias. Raises IndexError if no alias is
        configured.
        """
        aliases = np.asarray(entry["technique"]["tomoconfig"][name]).reshape(-1)
        positioners_start = entry.get("positioners", {}).get("positioners_start", {})
        for alias in aliases:
            alias_name = str(alias)
            if alias_name in positioners_start:
                return alias_name
        return str(aliases[0])

    def _flat_on(self, entry: Any) -> list[int]:
        """Return ``technique/scan/flat_on`` as a flat list of ints."""
        return (
            np.asarray(entry["technique"]["scan"]["flat_on"], dtype=int)
            .reshape(-1)
            .tolist()
        )

    def _return_ref_count(self, entry: Any, flat_on: list[int]) -> int:
        """Return the number of frames in the return-reference subscan.

        When return images are aligned to flats (and ``flat_on`` is
        non-empty), there is one frame per ``flat_on`` group plus one. Otherwise
        there is one frame per 90 degrees of the (<= 360 degrees) scan range,
        plus one.
        """
        aligned = bool(
            entry["technique"]["scan_flags"]["return_images_aligned_to_flats"]
        )
        if aligned and flat_on:
            return len(flat_on) + 1
        scan_range = abs(float(entry["technique"]["scan"]["scan_range"]))
        return int(min(scan_range, 360.0) / 90.0) + 1

    def _is_projection_kind(self, kind: Optional[str]) -> bool:
        """Return whether ``kind`` is one of the projection subscan types."""
        return str(kind) in _PROJECTION_SUBSCAN_KINDS

    def _subscan_plan(self, entry: Any) -> list[tuple[Optional[str], int]]:
        """Return ``(subscan type, frame count)`` for each subscan, in scan order.

        Subscans are ordered by the number following a ``scan`` prefix in
        their name (names without one sort first). Projection subscans take
        their frame counts from successive ``flat_on`` entries when
        ``flat_on`` is non-empty, otherwise from ``tomo_n``.

        :raises ValueError: for an unsupported subscan type.
        """
        subscans = entry["technique"]["subscans"]
        dark_n = int(entry["technique"]["scan"]["dark_n"])
        flat_n = int(entry["technique"]["scan"]["flat_n"])
        tomo_n = int(entry["technique"]["scan"]["tomo_n"])
        flat_on = self._flat_on(entry)
        return_ref_n = self._return_ref_count(entry, flat_on)
        projection_index, plan = 0, []

        def sort_key(name: str) -> tuple[int, str]:
            scan_suffix = name[4:] if name.startswith("scan") else name
            return (int(scan_suffix) if scan_suffix.isdigit() else 0, name)

        for name in sorted(subscans.keys(), key=sort_key):
            kind = str(subscans[name]["type"])
            if kind == "tomo:dark":
                count = dark_n
            elif kind == "tomo:flat":
                count = flat_n
            elif kind == "tomo:return_ref":
                count = return_ref_n
            elif self._is_projection_kind(kind) and flat_on:
                count = int(flat_on[projection_index])
                projection_index += 1
            elif self._is_projection_kind(kind):
                count = tomo_n
            else:
                raise ValueError(f"Unsupported subscan type for NXtomo: {kind!r}")
            plan.append((kind, count))
        return plan

    def _expanded_subscan_plan(self, entry: Any) -> list[tuple[Optional[str], int]]:
        """Return the sequence module's expanded ``(kind, frame count)`` plan."""
        return self._sequence_module(entry).expanded_subscan_plan(self, entry)

    def _build_image_key_control(self, scan: BlissScanType) -> np.ndarray:
        """Return the per-frame ``image_key`` array (int64) for the whole sequence.

        Kinds listed in ``_SUBSCAN_IMAGE_KEYS`` get their mapped key; all other
        kinds (projections) get 0. An empty plan yields an empty array.
        """
        entry = self._entry(scan)
        image_keys = [
            np.full(count, _SUBSCAN_IMAGE_KEYS.get(kind or "", 0), dtype=np.int64)
            for kind, count in self._expanded_subscan_plan(entry)
        ]
        if not image_keys:
            # np.concatenate([]) raises; mirror _default_synthetic_rotation.
            return np.asarray([], dtype=np.int64)
        return np.concatenate(image_keys)

    def _dtype_from_depth(self, depth: int) -> str:
        """Map a detector pixel depth to the smallest unsigned integer dtype name.

        Depths 1, 2, 4 and 8 are multiplied by 8 first. This presumably means
        they are given in bytes (TODO confirm). The result is then treated as a
        bit count.
        """
        if depth in {1, 2, 4, 8}:
            depth *= 8
        if depth <= 8:
            return "uint8"
        if depth <= 16:
            return "uint16"
        if depth <= 32:
            return "uint32"
        return "uint64"

    def _detector_shape_dtype(self, entry: Any, detector: str) -> tuple[list[int], str]:
        """Return ``([size[1], size[0]], dtype name)`` for ``detector``.

        The size order is swapped, presumably from (width, height) to
        (rows, cols) — TODO confirm.
        """
        size = np.asarray(
            entry["technique"]["detector"][detector]["size"], dtype=int
        ).reshape(-1)
        depth = np.asarray(
            entry["technique"]["detector"][detector]["depth"], dtype=int
        ).reshape(-1)[0]
        dtype = self._dtype_from_depth(int(depth))
        return [int(size[1]), int(size[0])], dtype

    def _detector_path(self, detector: str) -> str:
        """Return the HDF5 dataset path of the detector frames inside an image file."""
        return f"/entry_0000/instrument/{detector}/data"

    def _images_root(self, scan: BlissScanType, scan_number: int, detector: str) -> str:
        """Return the image file path prefix for ``scan_number`` and ``detector``.

        The prefix comes from ``scan_saving.images_path``; a 4-digit file index
        and suffix are appended by the caller.
        """
        scan_number_value = scan.scan_saving.scan_number_format % int(scan_number)
        return scan.scan_saving.images_path.format(
            scan_number=scan_number_value,
            img_acq_device=detector,
        )

    def _split_frame_count(self, frame_count: int, frames_per_file: int) -> list[int]:
        """Split ``frame_count`` into per-file chunks of at most ``frames_per_file``.

        A non-positive ``frames_per_file`` means everything goes in one file.
        """
        if frames_per_file <= 0:
            return [int(frame_count)]
        remaining = int(frame_count)
        splits = []
        while remaining > 0:
            current = min(int(frames_per_file), remaining)
            splits.append(current)
            remaining -= current
        return splits

    def _build_detector_virtual_sources(
        self, scan: BlissScanType
    ) -> list[dict[str, Any]]:
        """Return all detector source descriptors of the sequence, in frame order."""
        records = self._subscan_records(scan, self._entry(scan))
        return [source for record in records for source in record["sources"]]

    def _subscan_records(self, scan: BlissScanType, entry: Any) -> list[dict[str, Any]]:
        """Return one record per expanded subscan, with its frame range and source files.

        Subscan ``i`` (1-based) is read from scan number ``root + i``. Each
        record holds ``kind``, a global ``[frame_start, frame_stop)`` range and
        ``sources``, which lists one descriptor per image file
        (``file_path``, ``data_path``, ``shape``, ``dtype`` and a global frame
        range).
        """
        scan_number = int(entry["scan_nb"])
        detector = self._detector_name(entry)
        image_shape, dtype = self._detector_shape_dtype(entry, detector)
        data_path = self._detector_path(detector)
        frames_per_file = int(entry["technique"]["saving"]["frames_per_file"])
        suffix = ".h5"
        expanded_plan = self._expanded_subscan_plan(entry)
        records = []
        frame_start = 0
        for index, (kind, frame_count) in enumerate(expanded_plan, start=1):
            sources = []
            splits = self._split_frame_count(frame_count, frames_per_file)
            root = self._images_root(scan, scan_number + index, detector)
            source_frame_start = frame_start
            for file_index, file_frame_count in enumerate(splits):
                sources.append(
                    {
                        "file_path": f"{root}{file_index:04d}{suffix}",
                        "data_path": data_path,
                        "shape": [int(file_frame_count), *image_shape],
                        "dtype": dtype,
                        "frame_start": int(source_frame_start),
                        "frame_stop": int(source_frame_start + file_frame_count),
                    }
                )
                source_frame_start += file_frame_count
            records.append(
                {
                    "kind": kind,
                    "frame_start": int(frame_start),
                    "frame_stop": int(frame_start + frame_count),
                    "sources": sources,
                }
            )
            frame_start += frame_count
        return records

    def _slice_record(
        self, record: dict[str, Any], start: int, stop: int
    ) -> dict[str, Any]:
        """Return ``record`` restricted to the global frame range ``[start, stop)``.

        :raises ValueError: if the range is empty or outside the record, or if
            it would cut through a source file.
        """
        if (
            start < record["frame_start"]
            or stop > record["frame_stop"]
            or start >= stop
        ):
            raise ValueError("Invalid frame slice for subscan record")
        sources = []
        for source in record["sources"]:
            if source["frame_stop"] <= start or source["frame_start"] >= stop:
                continue
            if source["frame_start"] < start or source["frame_stop"] > stop:
                raise ValueError(
                    "Cannot split turn-specific NXtomo because frames_per_file does not align with turn boundaries"
                )
            sources.append(source)
        return {
            "kind": record["kind"],
            "frame_start": int(start),
            "frame_stop": int(stop),
            "sources": sources,
        }

    def _detector_pixel_size_um(
        self, entry: Any, detector: str
    ) -> tuple[Optional[float], Optional[float]]:
        """Return the detector ``(x, y)`` pixel size in micrometres.

        ``technique/optic/optics_pixel_size`` is used when present (default
        unit um). Otherwise ``technique/detector/<detector>/pixel_size`` is used
        (default unit m). In both cases a single value is used for both axes.
        """
        technique = entry["technique"]
        optic_info = technique.get("optic")
        if isinstance(optic_info, dict) and "optics_pixel_size" in optic_info:
            unit = str(optic_info.get("optics_pixel_size@units", "um")).strip()
            raw_values = optic_info["optics_pixel_size"]
            field_name = "optics_pixel_size"
        else:
            detector_info = technique["detector"][detector]
            unit = str(detector_info.get("pixel_size@units", "m")).strip()
            raw_values = detector_info["pixel_size"]
            field_name = "pixel_size"
        converted_values = [
            self._convert_quantity(value, unit, "um", field_name)
            for value in np.asarray(raw_values, dtype=float).reshape(-1)
        ]
        if len(converted_values) == 1:
            # One value applies to both axes (previously only handled for optics).
            converted_values *= 2
        return float(converted_values[0]), float(converted_values[1])

    def _detector_data_axes(self, entry: Any) -> list[str]:
        """Return the two detector data axes from instrument metadata.

        This is deliberately best-effort: on any failure, a warning is logged
        and ``["-z", "y"]`` is returned.
        """
        try:
            metadata_name = self._detector_axes_metadata_name(entry)
            axes = np.asarray(entry["instrument"][metadata_name]["data_axes"]).reshape(
                -1
            )
            if axes.size != 2:
                raise ValueError(
                    "instrument/<tomo_detector>/data_axes must contain 2 values"
                )
            return [
                str(axis.decode() if isinstance(axis, bytes) else axis) for axis in axes
            ]
        except Exception as exc:
            _logger.warning(
                "Could not resolve detector_data_axes from scan metadata (%s)", exc
            )
            return ["-z", "y"]

    def _sample_pixel_size_um(self, entry: Any) -> float:
        """Return ``technique/scan/sample_pixel_size`` in micrometres (default unit um)."""
        scan_info = entry["technique"]["scan"]
        unit = str(scan_info.get("sample_pixel_size@units", "um")).strip()
        return self._convert_quantity(
            scan_info["sample_pixel_size"], unit, "um", "sample_pixel_size"
        )

    def _exposure_time_s(self, entry: Any) -> float:
        """Return ``technique/scan/exposure_time`` in seconds (default unit s)."""
        scan_info = entry["technique"]["scan"]
        exposure_time = float(scan_info["exposure_time"])
        unit = str(scan_info.get("exposure_time@units", "s")).strip()
        return self._convert_quantity(exposure_time, unit, "s", "exposure_time")

    def _energy_keV(self, entry: Any) -> float:
        """Return ``technique/scan/energy`` in keV (default unit keV)."""
        scan_info = entry["technique"]["scan"]
        energy = float(scan_info["energy"])
        unit = str(scan_info.get("energy@units", "keV")).strip()
        return self._convert_quantity(energy, unit, "keV", "energy")

    def _position_array(self, entry: Any, alias_name: str) -> np.ndarray:
        """Return the start position(s) of the positioner behind ``alias_name``.

        If the position cannot be resolved (missing key, bad value, empty
        alias list), a warning is logged and ``[0.0]`` is returned.
        """
        try:
            positioners_start = entry["positioners"]["positioners_start"]
            return np.asarray(
                positioners_start[self._configured_name(entry, alias_name)], dtype=float
            ).reshape(-1)
        except (KeyError, IndexError, TypeError, ValueError) as exc:
            _logger.warning(
                "Could not resolve %s position from scan metadata, using 0.0 (%s)",
                alias_name,
                exc,
            )
            return np.asarray([0.0], dtype=float)

    def _return_ref_angles(
        self, entry: Any, start_angle: float, step_deg: float
    ) -> np.ndarray:
        """Return the rotation angles of the return-reference frames.

        When aligned to flats, the angles are those at each ``flat_on`` group
        start plus the final projection index (``tomo_n``), in reverse order.
        Otherwise ``_return_ref_count`` angles are spaced linearly from the end
        of the (<= 360 degrees) range back to ``start_angle``.
        """
        flat_on = self._flat_on(entry)
        aligned = bool(
            entry["technique"]["scan_flags"]["return_images_aligned_to_flats"]
        )
        if aligned and flat_on:
            group_starts = [0]
            total = 0
            for frames in flat_on[:-1]:
                total += int(frames)
                group_starts.append(total)
            group_starts.append(int(entry["technique"]["scan"]["tomo_n"]))
            return start_angle + step_deg * np.asarray(group_starts[::-1], dtype=float)
        count = self._return_ref_count(entry, flat_on)
        end_angle = start_angle + np.sign(step_deg) * min(
            abs(float(entry["technique"]["scan"]["scan_range"])), 360.0
        )
        return np.linspace(end_angle, start_angle, count, dtype=float)

    def _default_synthetic_rotation(self, entry: Any) -> np.ndarray:
        """Build per-frame rotation angles (degrees) from the scan parameters.

        The step is ``scan_range / tomo_n``, forced negative when
        ``rotation_is_clockwise`` is set. Projection frames advance by one step
        each. Dark and flat frames repeat the last angle reached. Return
        references use ``_return_ref_angles``.
        """
        start_angle = float(self._position_array(entry, "rotation")[0])
        step_deg = float(entry["technique"]["scan"]["scan_range"]) / float(
            entry["technique"]["scan"]["tomo_n"]
        )
        if entry["technique"]["tomoconfig"].get("rotation_is_clockwise"):
            step_deg = -abs(step_deg)
        projection_index = 0
        current_angle = start_angle
        segments = []
        for kind, frame_count in self._expanded_subscan_plan(entry):
            if kind in {"tomo:dark", "tomo:flat"}:
                segment = np.full(frame_count, current_angle, dtype=float)
            elif kind == "tomo:return_ref":
                segment = self._return_ref_angles(entry, start_angle, step_deg)
            else:
                segment = start_angle + step_deg * (
                    projection_index + np.arange(frame_count, dtype=float)
                )
                projection_index += frame_count
            if frame_count:
                current_angle = float(np.asarray(segment).reshape(-1)[-1])
            segments.append(np.asarray(segment, dtype=float).reshape(-1))
        return np.concatenate(segments) if segments else np.asarray([], dtype=float)

    def _synthetic_rotation(self, entry: Any) -> np.ndarray:
        """Return per-frame rotation angles from the module hook, else the default."""
        hook = self._module_hook(entry, "synthetic_rotation")
        if hook is not None:
            return hook(self, entry)
        return self._default_synthetic_rotation(entry)

    def _synthetic_translation(self, entry: Any, alias_name: str) -> np.ndarray:
        """Return the per-frame translation of ``alias_name`` from the sequence module."""
        return self._sequence_module(entry).synthetic_translation(
            self, entry, alias_name
        )

    def estimate_CoR(self, entry: Any) -> float:
        """Estimate the relative center of rotation.

        Delegates to ``calculate_relative_CoR_estimate``. Its inputs are the
        sample pixel size converted to mm, the first ``translation_y`` start
        position and the ``offset_mm`` parameter.
        """
        sample_pixel_size_um = self._sample_pixel_size_um(entry)
        translation_y_mm = self._position_array(entry, "translation_y")
        return calculate_relative_CoR_estimate(
            pixel_size_mm=float(sample_pixel_size_um) / 1000.0,
            translation_y_mm=float(translation_y_mm[0]),
            offset_mm=float(self.offset_mm),
        )

    def _slice_per_frame_array(
        self, value: Any, start: int, stop: int, total: int
    ) -> Any:
        """Return ``value[start:stop]`` if it holds ``total`` elements, else ``value`` unchanged."""
        array = np.asarray(value)
        if array.ndim == 0:
            return value
        if array.size == total:
            return array[start:stop]
        return value

    def _segment_specs(
        self, entry: Any, records: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Return the sequence module's segment specs (each with ``label`` and ``pieces``)."""
        return self._sequence_module(entry).segment_specs(self, entry, records)

    def _segment_definitions(
        self, scan: BlissScanType, entry: Optional[Any] = None
    ) -> list[dict[str, Any]]:
        """Return segment specs for ``scan``, deriving ``entry`` when not given."""
        entry = self._entry(scan) if entry is None else entry
        records = self._subscan_records(scan, entry)
        return self._segment_specs(entry, records)

    def _record_index_for_piece(
        self, records: list[dict[str, Any]], piece: dict[str, Any]
    ) -> int:
        """Return the index of the first record of the same kind that contains ``piece``.

        :raises ValueError: if no record contains the piece.
        """
        for index, record in enumerate(records):
            if record["kind"] != piece["kind"]:
                continue
            if (
                record["frame_start"] <= piece["frame_start"]
                and piece["frame_stop"] <= record["frame_stop"]
            ):
                return index
        raise ValueError("Could not map segment piece to subscan record")

    def _segment_completion_indices(
        self, scan: BlissScanType, entry: Optional[Any] = None
    ) -> dict[Optional[str], int]:
        """Map each segment label to the index of the last subscan record it needs."""
        entry = self._entry(scan) if entry is None else entry
        records = self._subscan_records(scan, entry)
        completion_indices = {}
        for spec in self._segment_specs(entry, records):
            completion_indices[spec["label"]] = max(
                self._record_index_for_piece(records, piece) for piece in spec["pieces"]
            )
        return completion_indices

    def _completed_labels(
        self, scan: BlissScanType, entry: Optional[Any] = None
    ) -> list[Optional[str]]:
        """Return the labels of the segments completed by ``scan``.

        These are the segments whose last needed record index equals the
        scan's ``index_in_sequence``. All labels are returned when
        ``index_in_sequence`` is absent.
        """
        entry = self._entry(scan) if entry is None else entry
        finished_index = scan.scan_info.get("index_in_sequence")
        completion_indices = self._segment_completion_indices(scan, entry)
        if finished_index is None:
            return list(completion_indices)
        finished_index = int(finished_index)
        return [
            label
            for label, completion_index in completion_indices.items()
            if completion_index == finished_index
        ]

    def _build_segment_inputs(
        self,
        scan: BlissScanType,
        entry: Any,
        label: Optional[str],
        pieces: list[dict[str, Any]],
        full_inputs: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """Restrict the whole-sequence ``full_inputs`` to one segment.

        - ``detector_data_*`` inputs are rebuilt from the pieces' sources.
        - ``nx_path`` is recomputed for ``label``.
        - ``sequence_number`` is renumbered from 0.
        - Numeric arrays with one value per frame are cut to the pieces' frame
          ranges.
        - All other inputs are copied unchanged.

        :returns: ewoks input descriptors for ``_TASK_IDENTIFIER``.
        """
        total = len(np.asarray(full_inputs["image_key_control"]).reshape(-1))
        nframes = sum(piece["frame_stop"] - piece["frame_start"] for piece in pieces)
        # Same source list for all four detector_data_* inputs: compute once.
        segment_sources = [source for piece in pieces for source in piece["sources"]]
        source_extractors = {
            "detector_data_file_paths": lambda source: str(source["file_path"]),
            "detector_data_h5_url": lambda source: str(source["data_path"]),
            "detector_data_shapes": lambda source: [
                int(item) for item in source["shape"]
            ],
            "detector_data_dtype": lambda source: str(source["dtype"]),
        }
        inputs = {}
        for name, value in full_inputs.items():
            extractor = source_extractors.get(name)
            if extractor is not None:
                inputs[name] = [extractor(source) for source in segment_sources]
            elif name == "nx_path":
                inputs[name] = self._build_output_path(
                    self._bliss_hdf5_path, label, entry
                )
            elif name == "sequence_number":
                inputs[name] = np.arange(nframes, dtype=np.uint32)
            else:
                array = np.asarray(value)
                # Per-frame data is always numeric. Checking the dtype keeps
                # e.g. the 2-entry detector_data_axes list from being sliced
                # when the sequence happens to contain exactly 2 frames.
                is_per_frame = (
                    array.ndim > 0
                    and array.size == total
                    and array.dtype.kind in "biufc"
                )
                if is_per_frame:
                    inputs[name] = np.concatenate(
                        [
                            array[piece["frame_start"] : piece["frame_stop"]]
                            for piece in pieces
                        ]
                    )
                else:
                    inputs[name] = value
        return [
            {"task_identifier": _TASK_IDENTIFIER, "name": name, "value": value}
            for name, value in inputs.items()
        ]

    def _get_input_sets(self, scan: BlissScanType) -> list[dict[str, Any]]:
        """Build the workflow inputs for every segment of the sequence.

        Side effects: sets ``self._bliss_hdf5_path`` (from
        ``scan_info["filename"]``) and ``self._output_path``.

        :returns: one ``{"label": ..., "inputs": [...]}`` dict per segment spec.
        """
        filename = scan.scan_info["filename"]
        self._bliss_hdf5_path = str(filename)
        entry = self._entry(scan)
        self._output_path = self._build_output_path(self._bliss_hdf5_path, entry=entry)
        records = self._subscan_records(scan, entry)
        detector = self._detector_name(entry)
        image_key_control = self._build_image_key_control(scan)
        rotation_angle_deg = self._synthetic_rotation(entry)
        nframes = int(image_key_control.size)
        exposure_time = self._exposure_time_s(entry)
        energy_keV = self._energy_keV(entry)
        sample_pixel_size_um = self._sample_pixel_size_um(entry)
        detector_x_pixel_size_um, detector_y_pixel_size_um = (
            self._detector_pixel_size_um(entry, detector)
        )
        full_inputs = {
            "nx_path": self._output_path,
            # Filled per segment by _build_segment_inputs.
            "detector_data_file_paths": [],
            "detector_data_h5_url": [],
            "detector_data_shapes": [],
            "detector_data_dtype": [],
            "image_key_control": image_key_control,
            "rotation_angle_deg": rotation_angle_deg,
            "sample_name": str(scan.scan_saving.sample_name),
            "title": str(entry["title"]),
            "start_time": str(entry["start_time"]),
            "estimated_cor": self.estimate_CoR(entry),
            "detector_data_axes": self._detector_data_axes(entry),
            "detector_x_pixel_size_um": detector_x_pixel_size_um,
            "detector_y_pixel_size_um": detector_y_pixel_size_um,
            "sample_x_pixel_size_um": sample_pixel_size_um,
            "sample_y_pixel_size_um": sample_pixel_size_um,
            "sample_detector_distance_mm": float(
                entry["technique"]["scan"]["sample_detector_distance"]
            ),
            "source_sample_distance_mm": float(
                entry["technique"]["scan"]["source_sample_distance"]
            ),
            "field_of_view": str(entry["technique"]["scan"]["field_of_view"]),
            "instrument_name": str(scan.scan_saving.beamline),
            "propagation_distance_mm": float(
                entry["technique"]["scan"]["effective_propagation_distance"]
            ),
            "energy_kev": energy_keV,
            "count_time_s": np.full(nframes, exposure_time, dtype=float),
            # NOTE(review): BLISS motor -> NXtomo axis mapping as in the original;
            # confirm against the beamline frame convention.
            "y_translation_mm": self._synthetic_translation(entry, "translation_z"),
            "z_translation_mm": self._synthetic_translation(entry, "sample_x"),
            "x_translation_mm": self._synthetic_translation(entry, "sample_y"),
            "sequence_number": np.arange(nframes, dtype=np.uint32),
        }
        return [
            {
                "label": spec["label"],
                "inputs": self._build_segment_inputs(
                    scan, entry, spec["label"], spec["pieces"], full_inputs
                ),
            }
            for spec in self._segment_specs(entry, records)
        ]

    def _get_inputs(self, scan: BlissScanType) -> list[dict[str, Any]]:
        """Return the inputs of the first segment only."""
        return self._get_input_sets(scan)[0]["inputs"]

    def _get_submit_arguments(
        self,
        scan: BlissScanType,
        input_set: dict[str, Any],
        entry: Any,
    ) -> dict[str, Any]:
        """Return JSON-serializable keyword arguments for ``submit`` for one segment."""
        kwargs = {"inputs": input_set["inputs"], "outputs": [{"all": True}]}
        upload_parameters = self._get_workflow_upload_parameters(
            scan, input_set["label"], entry
        )
        if upload_parameters:
            kwargs["upload_parameters"] = upload_parameters
        return _to_json_serializable(kwargs)

    def _get_workflow(self) -> dict[str, Any]:
        """Load the workflow JSON resolved via ``resource_filename("tomo", self.workflow)``."""
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(
            resource_filename("tomo", self.workflow), "r", encoding="utf-8"
        ) as wf:
            return json.load(wf)

    def execute_workflow(self, scan: BlissScanType) -> None:
        """Submit one NXtomo workflow per segment of a supported tomo sequence.

        Scans that are not part of a supported sequence are ignored.
        """
        if not self._is_tomo_sequence_scan(scan):
            return
        entry = self._entry(scan)
        # _get_input_sets also sets self._bliss_hdf5_path, which
        # workflow_destination relies on.
        for input_set in self._get_input_sets(scan):
            kwargs = self._get_submit_arguments(scan, input_set, entry)
            kwargs["convert_destination"] = self.workflow_destination(
                input_set["label"], entry
            )
            # A freshly loaded workflow dict per submission, so submissions
            # never share a mutable graph object.
            submit(args=(self._get_workflow(),), kwargs=kwargs, queue=self.queue)

    def _trigger_workflow_on_new_scan(self, scan: BlissScanType) -> None:
        """Trigger hook: run ``execute_workflow`` for ``scan``."""
        self.execute_workflow(scan)