Source code for blissoda.persistent.parameters

from pprint import pformat
from dataclasses import dataclass
from collections.abc import MutableMapping
from typing import Callable, Any, Iterator, Mapping, Iterable, Dict, Optional

import numpy

try:
    from bliss import current_session
except ImportError:
    current_session = None
try:
    from bliss.common.utils import autocomplete_property
except ImportError:
    autocomplete_property = property
try:
    from bliss.config.settings import HashObjSetting
except ImportError:
    try:
        from blissdata.settings import HashObjSetting
    except ImportError:
        HashObjSetting = None


[docs] @dataclass class ParameterInfo: name: str category: str = "parameters" hidden: Optional[bool] = None def __post_init__(self): if self.hidden is None: self.hidden = self.name.startswith("_")
[docs] class WithPersistentParameters: """Adds parameters as properties that will be stored in Redis .. code:: python class MyClass(WithPersistentParameters, parameters=["a", "b"]) pass myobj = MyClass() myobj.a = 10 myobj.b = None # remove """ _PARAMETERS: Dict[str, ParameterInfo] = dict() def __init__(self, **defaults) -> None: if current_session is None: raise ModuleNotFoundError("No module named 'bliss'") self._parameters = HashObjSetting( f"blissoda:{current_session.name}:{self.__class__.__name__}" ) for name, value in defaults.items(): if self._get_parameter(name) is None: self._set_parameter(name, value) def __init_subclass__(cls, parameters=None, **kw) -> None: super().__init_subclass__(**kw) if parameters: extra_parameters = list() for p in parameters: if isinstance(p, str): p = ParameterInfo(name=p) elif isinstance(p, Mapping): p = ParameterInfo(**p) extra_parameters.append(p) extra_parameters = {p.name: p for p in extra_parameters} else: extra_parameters = dict() cls._PARAMETERS = {**cls._PARAMETERS, **extra_parameters} for name in extra_parameters: add_parameter_property(cls, name) def __dir__(self) -> Iterable[str]: return list(super().__dir__()) + [ name for name, param in self._PARAMETERS.items() if not param.hidden ] def _info_categories(self) -> Dict[str, dict]: categories = dict() all_values = self._parameters.get_all() for name, param in self._PARAMETERS.items(): if param.hidden: continue category = categories.setdefault(param.category, dict()) category[name] = all_values.get(name) return categories def __info__(self) -> str: return "\n\n".join( [ f"{name.capitalize()}:\n {_format_info(category)}" for name, category in self._info_categories().items() if category ] ) def _get_parameter(self, name): v = self._parameters.get(name) if isinstance(v, dict): return RedisDictWrapper(name, self._parameters.get, self._set_parameter) return v def _set_parameter(self, name, value): if isinstance(value, RemoteDictWrapper): value = value.to_dict() self._parameters[name] = value def _del_parameter(self, name): self._parameters[name] = None def _raise_when_missing(self, *names): for name in names: if self._get_parameter(name) is None: raise AttributeError(f"parameter '{name}' is not set")
[docs] def add_parameter_property(cls, name): if hasattr(cls, name): return method = autocomplete_property(lambda self: self._get_parameter(name)) setattr(cls, name, method) method = getattr(cls, name).setter( lambda self, value: self._set_parameter(name, value) ) setattr(cls, name, method)
[docs] class RemoteDictWrapper(MutableMapping): """Whenever you get, set or delete the value, the entire dictionary is pushed/pull from a remote source""" def __dir__(self) -> Iterable[str]: return list(super().__dir__()) + list(self) def _get_all(self) -> dict: raise NotImplementedError def _set_all(self, value: Mapping) -> dict: raise NotImplementedError
[docs] def to_dict(self) -> dict: return self._get_all()
def __str__(self): return str(self._get_all()) def __repr__(self): return repr(self._get_all()) def __getitem__(self, key: str) -> Any: value = self._get_all()[key] if isinstance(value, dict): value = MemoryDictWrapper(self, key) return value def __setitem__(self, key: str, value: Any) -> None: adict = self._get_all() if isinstance(value, RemoteDictWrapper): value = value.to_dict() adict[key] = value return self._set_all(adict) def __delitem__(self, key: str) -> None: adict = self._get_all() del adict[key] return self._set_all(adict) def __iter__(self) -> Iterator[Any]: return self._get_all().__iter__() def __len__(self) -> int: return self._get_all().__len__()
[docs] class RedisDictWrapper(RemoteDictWrapper): def __init__(self, name: str, getter: Callable, setter: Callable) -> None: self._name = name self._getter = getter self._setter = setter def __getattr__(self, name): try: return self[name] except KeyError: raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) def __setattr__(self, name: str, value: Any) -> None: if name in ("_name", "_getter", "_setter"): return super().__setattr__(name, value) self[name] = value def _get_all(self) -> dict: adict = self._getter(self._name) if adict is None: return dict() return adict def _set_all(self, value: Mapping) -> None: self._setter(self._name, value)
[docs] class MemoryDictWrapper(RemoteDictWrapper): def __init__(self, parent: RemoteDictWrapper, name: str): self._parent = parent self._name = name def __getattr__(self, name): try: return self[name] except KeyError: raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) def __setattr__(self, name: str, value: Any) -> None: if name in ("_name", "_parent"): return super().__setattr__(name, value) self[name] = value def _get_all(self) -> dict: return self._parent._get_all()[self._name] def _set_all(self, value: Mapping) -> None: self._parent[self._name] = value
def _format_info(info: Dict) -> str: if not info: return "" rows = [ (str(name), pformat(value, width=60).split("\n")) for name, value in info.items() ] lengths = numpy.array( [[len(name), max(len(s) for s in value)] for name, value in rows] ) fmt = " ".join(["{{:<{}}}".format(n) for n in lengths.max(axis=0)]) lines = list() for name, value in rows: for i, s in enumerate(value): if i == 0: lines.append(fmt.format(name, s)) else: lines.append(fmt.format("", s)) return "\n ".join(lines)