Source code for sdt.io.yaml

# SPDX-FileCopyrightText: 2020 Lukas Schrangl <lukas.schrangl@tuwien.ac.at>
#
# SPDX-License-Identifier: BSD-3-Clause

"""Support for loading and saving :py:mod:`sdt` data types from/to YAML

This includes the :py:class:`ArrayDumper` and :py:class:`SafeArrayDumper`
YAML dumpers for pretty-printing numpy arrays and their corresponding loaders,
:py:class:`ArrayLoader` and :py:class:`SafeArrayLoader`.

Additionally, there are the :py:class:`Loader`/:py:class:`SafeLoader`/
:py:class:`Dumper`/:py:class:`SafeDumper` subclasses that have
representers/constructors for many types in the :py:mod:`sdt` package.

All of them can be used by passing them as the `Dumper`/`Loader` parameters
to :py:func:`yaml.dump`/:py:func:`yaml.load`.
"""
import collections

import yaml
import numpy as np


class ArrayNode(yaml.Node):
    """YAML node for numpy arrays"""
    pass


class ArrayEvent(yaml.NodeEvent):
    """YAML parser event for numpy arrays"""
    def __init__(self, anchor, tag, value):
        super().__init__(anchor)
        self.tag = tag
        self.value = value
        self.implicit = False


class ArrayDumperBase:
    """Base class implementing a readable representation of numpy arrays

    Uses :py:func:`numpy.array2string` to create a readable representation
    of an array as a nested list.
    """
    array_tag = "!array"

    def represent_numpy_array(self, array):
        return ArrayNode(self.array_tag, array, None, None)

    def serialize_array_node(self, node):
        alias = self.anchors[node]
        if node in self.serialized_nodes:
            self.emit(yaml.AliasEvent(alias))
        else:
            self.serialized_nodes[node] = True
            self.emit(ArrayEvent(alias, node.tag, node.value))

    def expect_array_node(self):
        self.process_anchor("&")
        self.process_tag()
        self.increase_indent()
        with np.printoptions(threshold=np.inf):
            s = np.array2string(self.event.value, separator=", ")
        for l in s.split("\n"):
            self.write_indent()
            self.write_plain(l)
        self.indent = self.indents.pop()
        self.state = self.states.pop()


class ArrayDumper(yaml.Dumper, ArrayDumperBase):
    """A :py:class:`yaml.Dumper` making numpy array dumps more readable

    This uses :py:func:`numpy.array2string` to create a readable representation
    of an array as a nested list.

    To use this dumper, simply pass it to :py:func:`yaml.dump` as the `Dumper`
    parameter.
    """
    def serialize_node(self, node, parent, index):
        if isinstance(node, ArrayNode):
            self.serialize_array_node(node)
        else:
            super().serialize_node(node, parent, index)

    def expect_node(self, root=False, sequence=False, mapping=False,
                    simple_key=False):
        if isinstance(self.event, ArrayEvent):
            self.expect_array_node()
        else:
            super().expect_node(root, sequence, mapping, simple_key)


class SafeArrayDumper(yaml.SafeDumper, ArrayDumperBase):
    """A :py:class:`yaml.SafeDumper` making numpy array dumps more readable

    This uses :py:func:`numpy.array2string` to create a readable representation
    of an array as a nested list.

    To use this dumper, simply pass it to :py:func:`yaml.dump` as the `Dumper`
    parameter.
    """
    def serialize_node(self, node, parent, index):
        if isinstance(node, ArrayNode):
            self.serialize_array_node(node)
        else:
            super().serialize_node(node, parent, index)

    def expect_node(self, root=False, sequence=False, mapping=False,
                    simple_key=False):
        if isinstance(self.event, ArrayEvent):
            self.expect_array_node()
        else:
            super().expect_node(root, sequence, mapping, simple_key)


ArrayDumper.add_representer(np.ndarray, ArrayDumper.represent_numpy_array)
SafeArrayDumper.add_representer(np.ndarray,
                                SafeArrayDumper.represent_numpy_array)


class ArrayLoaderBase:
    """Base class for loading arrays dumped by :py:class:`ArrayDumperBase`"""
    def construct_numpy_array(self, data):
        a = self.construct_sequence(data, deep=True)
        return np.array(a)


class ArrayLoader(yaml.Loader, ArrayLoaderBase):
    """A :py:class:`yaml.Loader` for loading numpy arrays

    that were dumped using :py:class:`ArrayDumper`.

    To use this dumper, simply pass it to :py:func:`yaml.load` as the `Loader`
    parameter.
    """
    pass


class SafeArrayLoader(yaml.SafeLoader, ArrayLoaderBase):
    """A :py:class:`yaml.SafeLoader` for loading numpy arrays

    that were dumped using :py:class:`SafeArrayDumper`.

    To use this dumper, simply pass it to :py:func:`yaml.load` as the `Loader`
    parameter.
    """
    pass


ArrayLoader.add_constructor(ArrayDumperBase.array_tag,
                            ArrayLoader.construct_numpy_array)
SafeArrayLoader.add_constructor(ArrayDumperBase.array_tag,
                                SafeArrayLoader.construct_numpy_array)


[docs]class Dumper(ArrayDumper): """A :py:class:`ArrayDumper` with support for many more types of the :py:mod:`sdt` package, e. g. :py:class:`roi.ROI`. """ pass
[docs]class SafeDumper(SafeArrayDumper): """A :py:class:`SafeArrayDumper` with support for many more types of the :py:mod:`sdt` package, e. g. :py:class:`roi.ROI`. """ pass
[docs]class Loader(ArrayLoader): """A :py:class:`ArrayLoader` with support for many more types of the :py:mod:`sdt` package, e. g. :py:class:`roi.ROI`. """ pass
[docs]class SafeLoader(SafeArrayLoader): """A :py:class:`SafeArrayLoader` with support for many more types of the :py:mod:`sdt` package, e. g. :py:class:`roi.ROI`. """ pass
# slice def slice_representer(dumper, data): d = (("start", data.start), ("stop", data.stop), ("step", data.step)) return dumper.represent_mapping("!slice", d) def slice_constructor(loader, data): val = loader.construct_mapping(data) return slice(val["start"], val["stop"], val["step"]) # dict-like def dict_representer(dumper, data): return dumper.represent_mapping(yaml.resolver.Resolver.DEFAULT_MAPPING_TAG, ((k, v) for k, v in data.items())) # Load mappings as dicts, which preserve the order since Python >=3.7 def dict_constructor(loader, data): return dict(loader.construct_pairs(data)) Dumper.add_representer(slice, slice_representer) SafeDumper.add_representer(slice, slice_representer) # Use dict_representer on dicts since that preserves the order Dumper.add_representer(dict, dict_representer) SafeDumper.add_representer(dict, dict_representer) SafeDumper.add_representer(collections.OrderedDict, dict_representer) # Represent numpy scalars as standard types for D in (Dumper, SafeDumper): # There is no float128 on Windows; do it the safe way float_t = (getattr(np, t) for t in ("float16", "float32", "float64", "float128") if hasattr(np, t)) for t in float_t: D.add_representer(t, D.represent_float) for t in (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64): D.add_representer(t, D.represent_int) Loader.add_constructor("!slice", slice_constructor) SafeLoader.add_constructor("!slice", slice_constructor) Loader.add_constructor(yaml.resolver.Resolver.DEFAULT_MAPPING_TAG, dict_constructor) SafeLoader.add_constructor(yaml.resolver.Resolver.DEFAULT_MAPPING_TAG, dict_constructor) def _class_representer_factory(cls): """Create a representer for class `cls` Parameters ---------- cls : class Class to be represented using :py:meth:`yaml.Dumper.represent_yaml_object`. This class needs to have a `yaml_tag` attribute and may optionally have a `flow_style` attribute. Returns ------- callable Function to be passed to `yaml.Dumper.add_representer` """ flow_style = getattr(cls, "yaml_flow_style", None) def rep(dumper, data): return dumper.represent_yaml_object(cls.yaml_tag, data, cls, flow_style=flow_style) return rep def _class_constructor_factory(cls): """Create a constructor for class `cls` Parameters ---------- cls : class Class to be constructed using :py:meth:`yaml.Dumper.construct_yaml_object`. Returns ------- callable Function to be passed to `yaml.Dumper.add_constructor` """ def cons(loader, node): return loader.construct_yaml_object(node, cls) return cons
[docs]def dump(data, stream=None, Dumper=Dumper, **kwds): """Wrapper around :py:func:`yaml.dump` using :py:class:`Dumper`""" return yaml.dump(data, stream, Dumper, **kwds)
[docs]def safe_dump(data, stream=None, **kwds): """Wrap PyYAML's :py:func:`yaml.dump` using :py:class:`SafeDumper`""" return yaml.dump(data, stream, SafeDumper, **kwds)
[docs]def dump_all(documents, stream=None, Dumper=Dumper, **kwds): """Wrap PyYAML's :py:func:`yaml.dump_all` using :py:class:`Dumper`""" return yaml.dump_all(documents, stream, Dumper, **kwds)
[docs]def safe_dump_all(documents, stream=None, **kwds): """Wrap PyYAML's :py:func:`yaml.dump_all` using :py:class:`SafeDumper`""" return yaml.dump_all(documents, stream, SafeDumper, **kwds)
[docs]def load(stream, Loader=Loader): """Wrap PyYAML's :py:func:`yaml.load` using :py:class:`Loader`""" return yaml.load(stream, Loader)
[docs]def safe_load(stream): """Wrap PyYAML's :py:func:`yaml.load` using :py:class:`SafeLoader`""" return yaml.load(stream, SafeLoader)
[docs]def load_all(stream, Loader=Loader): """Wrap PyYAML's :py:func:`yaml.load_all` using :py:class:`Loader`""" return yaml.load_all(stream, Loader)
[docs]def safe_load_all(stream): """Wrap PyYAML's :py:func:`yaml.load_all` using :py:class:`SafeLoader`""" return yaml.load_all(stream, SafeLoader)
[docs]def register_yaml_class(cls): """Add support for representing and loading a class A representer is be added to :py:class:`Dumper` and :py:class:`SafeDumper`. A loader is added to :py:class:`Loader` and :py:class:`SafeLoader`. The class should have a `yaml_tag` attribute and may have a `yaml_flow_style` attribute. If `to_yaml` or `from_yaml` class methods exist, they will be used for representing or constructing class instances (see PyYAML's :py:func:`yaml.Dumper.add_representer` and :py:func:`yaml.Loader.add_constructor` details). Otherwise, the default :py:func:`Dumper.represent_yaml_object` and :py:func:`Loader.construct_yaml_object` are used. Parameters ---------- cls : type Class to add support for """ rep = getattr(cls, "to_yaml", None) rep = rep if callable(rep) else _class_representer_factory(cls) Dumper.add_representer(cls, rep) SafeDumper.add_representer(cls, rep) cons = getattr(cls, "from_yaml", None) cons = cons if callable(cons) else _class_representer_factory(cls) Loader.add_constructor(cls.yaml_tag, cons) SafeLoader.add_constructor(cls.yaml_tag, cons)