# Source code for blissoda.persistent.parameters

import logging
import warnings
from collections.abc import MutableMapping
from dataclasses import dataclass
from itertools import zip_longest
from pprint import pformat
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Union

import numpy

from ..bliss_globals import current_session
from ..import_utils import is_available
from ..utils.deprecate import WithDeprecatedClassAttributes

try:
    from bliss.common.utils import autocomplete_property
except ImportError:
    autocomplete_property = property

try:
    from bliss.config.settings import HashObjSetting
except ImportError:
    try:
        from blissdata.settings import HashObjSetting
    except ImportError as ex:
        from .mock_redis import LocalHashObjSetting as HashObjSetting

        HashObjSetting._LOCAL_REASON = ex


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@dataclass
class ParameterInfo:
    """Declaration of one persistent parameter.

    When ``hidden`` is not given, it is derived from ``name``: parameters whose
    name starts with an underscore are hidden, all others are visible.
    """

    # Attribute name of the parameter on the owning class.
    name: str
    # Section under which the parameter is grouped when displayed.
    category: str = "parameters"
    # None means "derive from the name" (see __post_init__).
    hidden: Optional[bool] = None
    # Optional description shown next to the value.
    doc: Optional[str] = None
    # Former names of this parameter, kept for backward compatibility.
    deprecated_names: Optional[List[str]] = None
    # Optional callable applied to every assigned value before it is stored.
    validator: Optional[Callable] = None

    def __post_init__(self):
        # Leading-underscore names are private-looking, hence hidden by default.
        self.hidden = self.name.startswith("_") if self.hidden is None else self.hidden
class ParameterValue(NamedTuple):
    """Pair of a parameter value and its optional documentation string."""

    # The stored value of the parameter.
    value: Any
    # Optional human-readable description of the parameter.
    doc: Optional[str] = None
class WithPersistentParameters(WithDeprecatedClassAttributes):
    """Adds parameters as properties that will be stored in Redis

    .. code-block:: python

        class MyClass(WithPersistentParameters, parameters=["a", "b"]):
            pass

        myobj = MyClass()
        myobj.a = 10
        myobj.b = None  # remove

    Parameters are declared with the ``parameters`` class keyword. Each entry is
    a parameter name, a mapping of :class:`ParameterInfo` fields, or a
    :class:`ParameterInfo` instance.
    """

    # Parameter declarations, merged along the subclass hierarchy by __init_subclass__.
    _PARAMETERS: Dict[str, ParameterInfo] = dict()
    # Evaluated once when the class body runs; selects the session name used in keys.
    _HAS_BLISS: bool = is_available(current_session)

    def __init__(
        self,
        shared: bool = False,
        name: Optional[str] = None,
        defaults: Optional[Dict[str, Any]] = None,
        **deprecated_defaults: Any,
    ) -> None:
        r"""
        Initialize a `WithPersistentParameters` instance with Redis-backed storage.

        The Redis key used for storage depends on `name` and `shared`:

        +---------------------------------------------+--------------+-----------+
        | Redis key                                   | Shared       | Singleton |
        +=============================================+==============+===========+
        | blissoda:{class_name}                       | yes          | yes       |
        +---------------------------------------------+--------------+-----------+
        | blissoda:{class_name}:{name}                | yes          | no        |
        +---------------------------------------------+--------------+-----------+
        | blissoda:{session_name}:{class_name} (*)    | no           | yes       |
        +---------------------------------------------+--------------+-----------+
        | blissoda:{session_name}:{class_name}:{name} | no           | no        |
        +---------------------------------------------+--------------+-----------+

        (*): the default

        :param shared: If True, parameters are shared between Bliss sessions.
        :param name: If None, the parameters are a singleton in the scope of a
            single Bliss session (``shared=False``) or globally (``shared=True``).
        :param defaults: Initial values for parameters stored in Redis. A default
            is only written when the parameter currently has no value (None).
        :param \**deprecated_defaults: Deprecated way of passing defaults
            (merged with `defaults` if provided; `defaults` wins on conflicts).
        """
        defaults = self._merge_defaults(deprecated_defaults, defaults)

        # Ensure Redis object exists
        self._class_name = self.__class__.__name__
        self._session_name = current_session.name if self._HAS_BLISS else "nosession"
        self._redis_key = self._get_redis_key(name, shared)
        self._parameters = HashObjSetting(self._redis_key)

        # Initialize Redis object with defaults
        self._init_parameters(defaults)

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__module__}.{self.__class__.__name__} "
            f"object at {hex(id(self))} REDIS KEY={self._redis_key!r}>"
        )

    def _get_redis_key(self, name: Optional[str], shared: bool) -> str:
        """Build the Redis key for a scope (see the table in ``__init__``)."""
        if shared:
            # Global scope
            if name:
                # instance in the global scope
                return f"blissoda:{self._class_name}:{name}"
            # singleton in the global scope
            return f"blissoda:{self._class_name}"
        # Session scope
        if name:
            # instance in the session scope
            return f"blissoda:{self._session_name}:{self._class_name}:{name}"
        # singleton in the session scope (DEFAULT)
        return f"blissoda:{self._session_name}:{self._class_name}"

    def copy_and_remove_parameters(
        self,
        shared: bool,
        name: Optional[str],
    ) -> None:
        """Copy&Remove parameters from another scope (global or session, singleton or instance).

        Values from the other scope take precedence over the values already
        stored under this instance's key. The other scope is cleared afterwards.

        :param shared: scope of the source key (same meaning as in ``__init__``).
        :param name: instance name of the source key (same meaning as in ``__init__``).
        """
        redis_key = self._get_redis_key(name, shared)
        if self._redis_key == redis_key:
            logger.warning("Same Redis key %r. Nothing to move.", redis_key)
            return
        if not self._parameters.connection.exists(redis_key):
            logger.warning("Redis key %r does not exist. Nothing to move.", redis_key)
            return

        obsolete = HashObjSetting(redis_key)
        # Values of the obsolete scope override the current ones.
        parameters = {**self._parameters.get_all(), **obsolete.get_all()}
        # Distinct loop name: do not shadow the `name` argument.
        for param_name, value in parameters.items():
            self._set_parameter(param_name, value)
        obsolete.clear()

        logger.warning(
            "The Redis key %r has been replaced by %r",
            redis_key,
            self._redis_key,
        )

    @classmethod
    def _merge_defaults(
        cls,
        deprecated_defaults: Dict[str, Any],
        defaults: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Combine deprecated keyword defaults with ``defaults``.

        Logs and emits a ``DeprecationWarning`` when keyword defaults are used.
        Entries of ``defaults`` win on conflicting keys.
        """
        if deprecated_defaults:
            warning = f"{cls.__name__}(**defaults) is deprecated, use {cls.__name__}(defaults={{...}})"
            logger.warning(warning)
            # stacklevel=3: skip this helper and __init__ so the warning points
            # at the code that constructed the instance.
            warnings.warn(warning, DeprecationWarning, stacklevel=3)
        if defaults:
            return {**deprecated_defaults, **defaults}
        return deprecated_defaults

    def _init_parameters(self, defaults: dict) -> None:
        """Migrate deprecated parameter names, then apply defaults to unset parameters."""
        for name, param in self._PARAMETERS.items():
            self._remove_deprecated_parameters(name, param.deprecated_names)
        for name, value in defaults.items():
            if self._get_parameter(name) is None:
                self._set_parameter(name, value)

    def __init_subclass__(
        cls,
        parameters: Optional[List[Union[str, Mapping, ParameterInfo]]] = None,
        **kw,
    ) -> None:
        """Register the declared parameters and create their properties.

        :param parameters: parameter names, mappings of ParameterInfo fields,
            or ParameterInfo instances.
        :raises TypeError: for any other kind of parameter declaration.
        """
        super().__init_subclass__(**kw)
        new_parameters: Dict[str, ParameterInfo] = dict()
        for p in parameters or list():
            if isinstance(p, str):
                p = ParameterInfo(name=p)
            elif isinstance(p, Mapping):
                p = ParameterInfo(**p)
            elif not isinstance(p, ParameterInfo):
                raise TypeError(
                    f"Parameter declarations of {cls.__name__} must be str, Mapping "
                    f"or ParameterInfo, not {type(p).__name__}"
                )
            new_parameters[p.name] = p
        # Copy rather than mutate so parent classes keep their own declarations.
        cls._PARAMETERS = {**cls._PARAMETERS, **new_parameters}
        for name, param in new_parameters.items():
            _add_parameter_property(cls, name, param.validator)
            for deprecated_name in param.deprecated_names or list():
                _add_deprecated_parameter_property(
                    cls, deprecated_name, name, param.validator
                )

    def __dir__(self) -> Iterable[str]:
        return list(super().__dir__()) + [
            name for name, param in self._PARAMETERS.items() if not param.hidden
        ]

    def _info_categories(self) -> Dict[str, dict]:
        """Group the non-hidden parameters by category with their stored values."""
        categories = dict()
        all_values = self._parameters.get_all()
        for name, param in self._PARAMETERS.items():
            if param.hidden:
                continue
            category = categories.setdefault(param.category, dict())
            value = all_values.get(name, None)
            category[name] = ParameterValue(value=value, doc=param.doc)
        return categories

    def __info__(self) -> str:
        return "\n" + "\n\n".join(
            [
                f"{name.capitalize()}:\n {_format_info_category(category)}"
                for name, category in self._info_categories().items()
                if category
            ]
        )

    def _get_parameter(self, name: str):
        """Return a stored value; dictionaries come back as live RedisDictWrapper views."""
        v = self._parameters.get(name)
        if isinstance(v, dict):
            return RedisDictWrapper(name, self._parameters.get, self._set_parameter)
        return v

    def _set_parameter(self, name: str, value):
        """Store a value; wrapper objects are stored as plain dictionaries."""
        if isinstance(value, RemoteDictWrapper):
            value = value.to_dict()
        self._parameters[name] = value

    def _del_parameter(self, name: str):
        """Remove a parameter by storing None."""
        self._parameters[name] = None

    def _remove_deprecated_parameters(
        self, name: str, deprecated_names: Optional[List[str]]
    ) -> None:
        """Move values stored under deprecated names to ``name``.

        Every non-None deprecated entry is removed. The first one found is
        copied to ``name`` unless ``name`` already has a value.
        """
        if not deprecated_names:
            return
        value = self._parameters.get(name)
        for deprecated_name in deprecated_names:
            deprecated_value = self._parameters.get(deprecated_name)
            if deprecated_value is not None:
                self._del_parameter(deprecated_name)
                if value is None:
                    value = deprecated_value
                    self._set_parameter(name, value)

    def _raise_when_missing(self, *names):
        """Raise AttributeError for the first of ``names`` whose value is None."""
        for name in names:
            if self._get_parameter(name) is None:
                raise AttributeError(f"parameter '{name}' is not set")
def _check_validator(attr_name: str, validator: Optional[Callable]) -> None:
    """Raise TypeError when ``validator`` is neither None nor callable."""
    if validator is not None and not callable(validator):
        raise TypeError(f"Validator for {attr_name} is not callable")


def _add_parameter_property(cls, name: str, validator: Optional[Callable]) -> None:
    """Add a property to a `WithPersistentParameters` instance which sets
    and gets a persistent parameter.

    Nothing is done when ``cls`` already has an attribute ``name``. When given,
    ``validator`` is applied to every assigned value and its result is stored.

    :raises TypeError: when ``validator`` is not callable (checked before ``cls``
        is modified).
    """
    if hasattr(cls, name):
        return
    _check_validator(name, validator)

    def getter(self):
        return self._get_parameter(name)

    def setter(self, value):
        if validator is not None:
            value = validator(value)
        self._set_parameter(name, value)

    setattr(cls, name, autocomplete_property(getter).setter(setter))


def _add_deprecated_parameter_property(
    cls, old_name: str, new_name: str, validator: Optional[Union[type, Callable]]
) -> None:
    """Add a property to a `WithPersistentParameters` instance which sets
    and gets a persistent parameter with a deprecated name.

    Reads and writes are forwarded to ``new_name`` and log a deprecation
    warning. Nothing is done when ``cls`` already has an attribute ``old_name``.

    :raises TypeError: when ``validator`` is not callable (checked before ``cls``
        is modified).
    """
    if hasattr(cls, old_name):
        return
    _check_validator(old_name, validator)

    def _warn_deprecated() -> None:
        logger.warning(
            "'%s' is deprecated and will be removed in a future version. Use '%s' instead.",
            old_name,
            new_name,
        )

    def getter(self):
        _warn_deprecated()
        return self._get_parameter(new_name)

    def setter(self, value):
        _warn_deprecated()
        if validator is not None:
            value = validator(value)
        self._set_parameter(new_name, value)

    setattr(cls, old_name, autocomplete_property(getter).setter(setter))
class RemoteDictWrapper(MutableMapping):
    """Whenever you get, set or delete the value, the entire dictionary is
    pushed/pull from a remote source

    Subclasses implement ``_get_all`` (fetch the whole dictionary) and
    ``_set_all`` (store the whole dictionary). Nested dictionaries are
    returned as :class:`MemoryDictWrapper` views on this wrapper.
    """

    def __dir__(self) -> Iterable[str]:
        # Only string keys can be attribute names. dir() sorts its result, so a
        # non-str key (e.g. an int) mixed with attribute names would raise TypeError.
        return list(super().__dir__()) + [key for key in self if isinstance(key, str)]

    def _get_all(self) -> dict:
        """Fetch the complete dictionary from the remote source."""
        raise NotImplementedError

    def _set_all(self, value: Mapping) -> None:
        """Store the complete dictionary in the remote source."""
        raise NotImplementedError

    def to_dict(self) -> dict:
        """Return the complete dictionary as fetched from the remote source."""
        return self._get_all()

    def __str__(self):
        return str(self._get_all())

    def __repr__(self):
        return repr(self._get_all())

    def __getitem__(self, key: str) -> Any:
        value = self._get_all()[key]
        if isinstance(value, dict):
            # Wrap nested dictionaries so that item assignment on them is
            # written back through this wrapper.
            value = MemoryDictWrapper(self, key)
        return value

    def __setitem__(self, key: str, value: Any) -> None:
        adict = self._get_all()
        if isinstance(value, RemoteDictWrapper):
            # Store a plain snapshot rather than the wrapper object.
            value = value.to_dict()
        adict[key] = value
        self._set_all(adict)

    def __delitem__(self, key: str) -> None:
        adict = self._get_all()
        del adict[key]
        self._set_all(adict)

    def __iter__(self) -> Iterator[Any]:
        return iter(self._get_all())

    def __len__(self) -> int:
        return len(self._get_all())
class RedisDictWrapper(RemoteDictWrapper):
    """Dictionary-valued persistent parameter, read and written as a whole.

    Keys are also reachable as attributes (``wrapper.key``), and assigning a
    non-internal attribute writes the corresponding key.

    :param name: parameter name passed to ``getter`` and ``setter``.
    :param getter: called as ``getter(name)``; returns the stored dictionary or None.
    :param setter: called as ``setter(name, value)`` to store the whole dictionary.
    """

    def __init__(self, name: str, getter: Callable, setter: Callable) -> None:
        self._name = name
        self._getter = getter
        self._setter = setter

    def __getattr__(self, name):
        # Only reached when normal lookup fails. On instances created without
        # __init__ (copy.copy, unpickling) the internal attributes are missing;
        # resolving them through self[...] would call _get_all, which reads
        # those same attributes again and recurses without bound.
        if name in ("_name", "_getter", "_setter"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ("_name", "_getter", "_setter"):
            return super().__setattr__(name, value)
        self[name] = value

    def _get_all(self) -> dict:
        adict = self._getter(self._name)
        if adict is None:
            return dict()
        return adict

    def _set_all(self, value: Mapping) -> None:
        self._setter(self._name, value)
class MemoryDictWrapper(RemoteDictWrapper):
    """View onto the nested dictionary stored under key ``name`` of ``parent``.

    Reads fetch the parent's whole dictionary. Writes replace the nested
    dictionary through ``parent[name] = ...``. Keys are also reachable as
    attributes.

    :param parent: wrapper holding the nested dictionary.
    :param name: key of the nested dictionary in ``parent``.
    """

    def __init__(self, parent: RemoteDictWrapper, name: str):
        self._parent = parent
        self._name = name

    def __getattr__(self, name):
        # Only reached when normal lookup fails. On instances created without
        # __init__ (copy.copy, unpickling) the internal attributes are missing;
        # resolving them through self[...] would call _get_all, which reads
        # those same attributes again and recurses without bound.
        if name in ("_name", "_parent"):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ("_name", "_parent"):
            return super().__setattr__(name, value)
        self[name] = value

    def _get_all(self) -> dict:
        return self._parent._get_all()[self._name]

    def _set_all(self, value: Mapping) -> None:
        self._parent[self._name] = value
def _format_info_category(category: Dict) -> str:
    """Render one category of parameters as a left-aligned text table.

    Each entry yields a row with three columns:
    - the parameter name;
    - its value, pretty-printed with ``pformat(width=60)``;
    - its documentation, present only for :class:`ParameterValue` entries with
      a truthy ``doc``. Documentation containing newlines is split on them;
      otherwise it is cut into chunks of 60 characters.

    When one cell spans several lines, the other cells of that row are padded
    with empty strings.

    :param category: mapping of parameter name to a ``ParameterValue`` or a raw value.
    :returns: the table lines joined with ``"\\n "``, or ``""`` for an empty category.
    """
    if not category:
        return ""

    table: List[List[List[str]]] = []
    for key, entry in category.items():
        doc_lines: List[str] = []
        if not isinstance(entry, ParameterValue):
            raw_value = entry
        else:
            raw_value = entry.value
            if entry.doc:
                text = str(entry.doc)
                if "\n" in text:
                    doc_lines = text.split("\n")
                else:
                    doc_lines = [
                        text[start : start + 60] for start in range(0, len(text), 60)
                    ]
        value_lines = pformat(raw_value, width=60).split("\n")
        table.append([[str(key)], value_lines, doc_lines])

    # Column width = longest line of any cell in that column (0 for empty cells).
    widths = [
        max((max(len(text) for text in cell) if cell else 0) for cell in column)
        for column in zip(*table)
    ]
    template = " ".join(f"{{:<{width}}}" for width in widths)

    rendered: List[str] = []
    for row in table:
        for cells in zip_longest(*row, fillvalue=""):
            rendered.append(template.format(*cells))
    return "\n ".join(rendered)