Source code for sdt.helper.slicerator

# Copyright (c) 2015, Daniel B. Allan
# SPDX-FileCopyrightText: 2020 Lukas Schrangl <lukas.schrangl@tuwien.ac.at>
#
# SPDX-License-Identifier: BSD-3-Clause
#
# Based on https://github.com/soft-matter/slicerator

"""A lazy-loading, fancy-slicable iterable

forked from https://github.com/soft-matter/slicerator (originally released
under MIT license).
"""
import collections
import itertools
from functools import wraps
from copy import copy
import inspect
from contextlib import suppress


def _iter_attr(obj):
    try:
        for ns in [obj] + obj.__class__.mro():
            for attr in ns.__dict__:
                yield ns.__dict__[attr]
    except AttributeError:
        return  # obj has no __dict__


[docs]class Slicerator: """A generator that supports fancy indexing When sliced using any iterable with a known length, it returns another object like itself, a Slicerator. When sliced with an integer, it returns the data payload. Also, the attributes of the parent object can be propagated, exposed through the child Slicerators. By default, no attributes are propagated. Attributes can be white-listed by using the optional parameter `propagated_attrs`. Methods taking an index will be remapped if they are decorated with `index_attr`. They also have to be present in the `propagate_attrs` list. """ _slicerator_flag = True def __init__(self, ancestor, indices=None, length=None, propagate_attrs=None): """Parameters ---------- ancestor : object indices : iterable Giving indices into `ancestor`. Required if len(ancestor) is invalid. length : integer length of indices This is required if `indices` is a generator, that is, if `len(indices)` is invalid propagate_attrs : list of str, optional list of attributes to be propagated into Slicerator Examples -------- Slicing on a Slicerator returns another Slicerator: >>> v = Slicerator([0, 1, 2, 3], range(4), 4) >>> v1 = v[:2] >>> type(v[:2]) Slicerator >>> v2 = v[::2] >>> type(v2) Slicerator >>> v2[0] 0 Unless the slice itself has an unknown length, which makes slicing impossible: >>> v3 = v2((i for i in [0])) # argument is a generator >>> type(v3) generator """ if indices is None and length is None: try: length = len(ancestor) indices = list(range(length)) except TypeError: raise ValueError("The length parameter is required in this " "case because len(ancestor) is not valid.") elif indices is None: indices = list(range(length)) elif length is None: try: length = len(indices) except TypeError: raise ValueError("The length parameter is required in this " "case because len(indices) is not valid.") # when list of propagated attributes are given explicitly, # take this list and ignore the class definition if propagate_attrs is not None: self._propagate_attrs = propagate_attrs else: # check propagated_attrs field from the ancestor definition self._propagate_attrs = [] if hasattr(ancestor, '_propagate_attrs'): self._propagate_attrs += ancestor._propagate_attrs if hasattr(ancestor, 'propagate_attrs'): self._propagate_attrs += ancestor.propagate_attrs # add methods having the _propagate flag for attr in _iter_attr(ancestor): if hasattr(attr, '_propagate_flag'): self._propagate_attrs.append(attr.__name__) self._len = length self._ancestor = ancestor self._indices = indices
[docs] @classmethod def from_func(cls, func, length, propagate_attrs=None): """ Make a Slicerator from a function that accepts an integer index Parameters ---------- func : callable callable that accepts an integer as its argument length : int number of elements; used to supposed revserse slicing like [-1] propagate_attrs : list, optional list of attributes to be propagated into Slicerator """ class Dummy: def __getitem__(self, i): return func(i) def __len__(self): return length return cls(Dummy(), propagate_attrs=propagate_attrs)
[docs] @classmethod def from_class(cls, some_class, propagate_attrs=None): """Make an existing class support fancy indexing via Slicerator objects When sliced using any iterable with a known length, it returns a Slicerator. When sliced with an integer, it returns the data payload. Also, the attributes of the parent object can be propagated, exposed through the child Slicerators. By default, no attributes are propagated. Attributes can be white_listed in the following ways: 1. using the optional parameter `propagate_attrs`; the contents of this list will overwrite any other list of propagated attributes 2. using the @propagate_attr decorator inside the class definition 3. using a `propagate_attrs` class attribute inside the class definition The difference between options 2 and 3 appears when subclassing. As option 2 is bound to the method, the method will always be propagated. On the contrary, option 3 is bound to the class, so this can be overwritten by the subclass. Methods taking an index will be remapped if they are decorated with `index_attr`. This decorator does not ensure that the method is propagated. The existing class should support indexing (:py:meth:`__getitem__` method) and it should define a length (:py:meth:`__len__`). The result will look exactly like the existing class (:py:attr:`__name__`, :py:attr:`__doc__`, :py:attr:`__module__`, :py:meth:`__repr__` will be propagated), but :py:meth:`__getitem__` will be renamed to :py:meth:`_get` and :py:meth:`__getitem__` will produce a :py:class:`Slicerator` object when sliced. Parameters ---------- some_class : type propagated_attrs : list, optional list of attributes to be propagated into Slicerator this will overwrite any other propagation list """ class SliceratorSubclass(some_class): _slicerator_flag = True _get = some_class.__getitem__ if hasattr(some_class, '__doc__'): __doc__ = some_class.__doc__ # for Python 2, do it here def __getitem__(self, i): """Getitem supports repeated slicing via Slicerator objects.""" indices, new_length = key_to_indices(i, len(self)) if new_length is None: return self._get(indices) else: return cls(self, indices, new_length, propagate_attrs) for name in ['__name__', '__module__', '__repr__']: try: setattr(SliceratorSubclass, name, getattr(some_class, name)) except AttributeError: pass return SliceratorSubclass
@property def indices(self): # Advancing indices won't affect this new copy of self._indices. indices, self._indices = itertools.tee(iter(self._indices)) return indices def _get(self, key): return self._ancestor[key] def _map_index(self, key): if key < -self._len or key >= self._len: raise IndexError("Key out of range") try: abs_key = self._indices[key] except TypeError: key = key if key >= 0 else self._len + key for _, i in zip(range(key + 1), self.indices): abs_key = i return abs_key def __repr__(self): msg = "Sliced {0}. Original repr:\n".format( type(self._ancestor).__name__) old = '\n'.join(" " + ln for ln in repr(self._ancestor).split('\n')) return msg + old def __iter__(self): return (self._get(i) for i in self.indices) def __len__(self): return self._len def __getitem__(self, key): """for data access""" if not (isinstance(key, slice) or isinstance(key, collections.abc.Iterable)): return self._get(self._map_index(key)) else: rel_indices, new_length = key_to_indices(key, len(self)) if new_length is None: return (self[k] for k in rel_indices) indices = _index_generator(rel_indices, self.indices) return Slicerator(self._ancestor, indices, new_length, self._propagate_attrs) def __getattr__(self, name): # to avoid infinite recursion, always check if public field is there if '_propagate_attrs' not in self.__dict__: self._propagate_attrs = [] if name in self._propagate_attrs: attr = getattr(self._ancestor, name) if (isinstance(attr, SliceableAttribute) or hasattr(attr, '_index_flag')): return SliceableAttribute(self, attr) else: return attr raise AttributeError def __getstate__(self): # When serializing, return a list of the sliced data # Any exposed attrs are lost. return list(self) def __setstate__(self, data_as_list): # When deserializing, restore a Slicerator instance return self.__init__(data_as_list)
def key_to_indices(key, length): """Converts a fancy key into a list of indices. Parameters ---------- key : slice, iterable of numbers, or boolean mask length : integer length of object that will be indexed Returns ------- indices, new_length """ if isinstance(key, slice): # if we have a slice, return a range object returning the indices start, stop, step = key.indices(length) indices = range(start, stop, step) return indices, len(indices) if isinstance(key, collections.abc.Iterable): # if the input is an iterable, doing 'fancy' indexing if hasattr(key, '__array__') and hasattr(key, 'dtype'): if key.dtype == bool: # if we have a bool array, set up masking and return indices nums = range(length) # This next line fakes up numpy's bool masking without # importing numpy. indices = [x for x, y in zip(nums, key) if y] return indices, sum(key) try: new_length = len(key) except TypeError: # The key is a generator; return a plain old generator. # Withoug using the generator, we cannot know its length. # Also it cannot be checked if values are in range. gen = ((_k if _k >= 0 else length + _k) for _k in key) return gen, None else: # The key is a list of in-range values. Check if they are in range. if any(_k < -length or _k >= length for _k in key): raise IndexError("Keys out of range") rel_indices = ((_k if _k >= 0 else length + _k) for _k in key) return rel_indices, new_length # other cases: it's possibly a number try: key = int(key) except TypeError: pass else: # allow negative indexing if -length < key < 0: return length + key, None elif 0 <= key < length: return key, None else: raise IndexError('index out of range') # in all other case, just return the key and let user deal with the type. return key, None def _index_generator(new_indices, old_indices): """Find locations of new_indicies in the ref. frame of the old_indices. Example: (1, 3), (1, 3, 5, 10) -> (3, 10) The point of all this trouble is that this is done lazily, returning a generator without actually looping through the inputs.""" # Use iter() to be safe. On a generator, this returns an identical ref. new_indices = iter(new_indices) try: n = next(new_indices) except StopIteration: # new_indices is empty return last_n = None done = False while True: old_indices_, old_indices = itertools.tee(iter(old_indices)) for i, o in enumerate(old_indices_): # If new_indices is not strictly monotonically increasing, break # and start again from the beginning of old_indices. if last_n is not None and n <= last_n: last_n = None break if done: return if i == n: last_n = n try: n = next(new_indices) except StopIteration: done = True # Don't stop yet; we still have one last thing to yield. yield o else: continue
[docs]class Pipeline: """A class to support lazy function evaluation on an iterable. When a :py:class:`Pipeline` object is indexed, it returns an element of its ancestor modified with a process function. """ _slicerator_flag = True def __init__(self, proc_func, *ancestors, propagate_attrs=None, propagate_how="first"): """Parameters ---------- proc_func : callable function that processes data returned by Slicerator. The function acts element-wise and is only evaluated when data is actually returned *ancestors : objects Object to be processed. propagate_attrs : set of str or None, optional Names of attributes to be propagated through the pipeline. If this is `None`, go through ancestors and look at `_propagate_attrs` and `propagate_attrs` attributes and search for attributes having a `_propagate_flag` attribute. Defaults to `None`. propagate_how : {'first', 'last'} or int, optional Where to look for attributes to propagate. If this is an integer, it specifies the index of the ancestor (in `ancestors`). If it is 'first', go through all ancestors starting with the first one until one is found that has the attribute. If it is 'last', go through the ancestors in reverse order. Defaults to 'first'. Example ------- Construct the pipeline object that multiplies elements by two: >>> ancestor = [0, 1, 2, 3, 4] >>> times_two = Pipeline(lambda x: 2*x, ancestor) Whenever the pipeline object is indexed, it takes the correct element from its ancestor, and then applies the process function. >>> times_two[3] 6 See also -------- pipeline """ # Only accept ancestors of the same length self._len = len(ancestors[0]) if not all(len(a) == self._len for a in ancestors): raise ValueError('Ancestors have to be of same length.') self._ancestors = ancestors self._proc_func = proc_func self._propagate_how = propagate_how # when list of propagated attributes are given explicitly, # take this list and ignore the class definition if propagate_attrs is not None: self._propagate_attrs = set(propagate_attrs) else: # check propagated_attrs field from the ancestor definition self._propagate_attrs = set() for a in self._get_prop_ancestors(): if hasattr(a, '_propagate_attrs'): self._propagate_attrs.update(a._propagate_attrs) if hasattr(a, 'propagate_attrs'): self._propagate_attrs.update(a.propagate_attrs) # add methods having the _propagate flag for attr in _iter_attr(a): if hasattr(attr, '_propagate_flag'): self._propagate_attrs.add(attr.__name__) def _get_prop_ancestors(self): """Get relevant ancestor(s) for attribute propagation Returns ------- list List of ancestors. """ if isinstance(self._propagate_how, int): return self._ancestors[self._propagate_how:self._propagate_how+1] if self._propagate_how == 'first': return self._ancestors if self._propagate_how == 'last': return self._ancestors[::-1] raise ValueError("propagate_how has to be an index, 'first', or " "'last'.") def _get(self, key): # We need to copy here: else any _proc_func that acts inplace would # change the ancestor value. return self._proc_func(*(copy(a[key]) for a in self._ancestors)) def __repr__(self): anc_str = ", ".join(type(a).__name__ for a in self._ancestors) msg = '({0},) processed through {1}. Original repr:\n '.format( anc_str, self._proc_func.__name__) old = [repr(a).replace('\n', '\n ') for a in self._ancestors] return msg + "\n ----\n ".join(old) def __len__(self): return self._len def __iter__(self): return (self._get(i) for i in range(len(self))) def __getitem__(self, i): """for data access""" indices, new_length = key_to_indices(i, len(self)) if new_length is None: return self._get(indices) else: return Slicerator(self, indices, new_length, self._propagate_attrs) def __getattr__(self, name): # to avoid infinite recursion, always check if public field is there pa = self.__dict__.get('_propagate_attrs', []) if not isinstance(pa, collections.abc.Iterable): raise TypeError('_propagate_attrs is not iterable') if name in pa: for a in self._get_prop_ancestors(): try: return getattr(a, name) except AttributeError: pass raise AttributeError('No attribute `{}` propagated.'.format(name)) def __getstate__(self): # When serializing, return a list of the processed data # Any exposed attrs are lost. return list(self) def __setstate__(self, data_as_list): # When deserializing, restore the Pipeline return self.__init__(lambda x: x, data_as_list)
_pipeline_types = (Slicerator, Pipeline) with suppress(ImportError): # Also support the pipeline decorator for the original slicerator's # `Pipeline` and `Slicerator` classes import slicerator as slc _pipeline_types = _pipeline_types + (slc.Slicerator, slc.Pipeline) del slc
[docs]def pipeline(func=None, **kwargs): """Decorator to enable lazy evaluation of a function. When the function is applied to a Slicerator or Pipeline object, it returns another lazily-evaluated, Pipeline object. When the function is applied to any other object, it falls back on its normal behavior. Parameters ---------- func : callable or type Function or class type for lazy evaluation retain_doc : bool, optional If True, don't modify `func`'s doc string to say that it has been made lazy. Defaults to False ancestor_count : int or 'all', optional Number of inputs to the pipeline. For instance, a function taking three parameters that adds up the elements of two :py:class:`Slicerators` and a constant offset would have ``ancestor_count=2``. If 'all', all the function's arguments are used for the pipeline. Defaults to 1. Returns ------- Pipeline Lazy function evaluation :py:class:`Pipeline` for `func`. See also -------- Pipeline Examples -------- Apply the pipeline decorator to your image processing function. >>> @pipeline ... def color_channel(image, channel): ... return image[channel, :, :] In order to preserve the original function's doc string (i. e. do not add a note saying that it was made lazy), use the decorator like so: >>> @pipeline(retain_doc=True) ... def color_channel(image, channel): ... '''This doc string will not be changed''' ... return image[channel, :, :] Passing a Slicerator the function returns a Pipeline that "lazily" applies the function when the images come out. Different functions can be applied to the same underlying images, creating independent objects. >>> red_images = color_channel(images, 0) >>> green_images = color_channel(images, 1) Pipeline functions can also be composed. >>> @pipeline ... def rescale(image): ... return (image - image.min())/image.ptp() >>> rescale(color_channel(images, 0)) The function can still be applied to ordinary images. The decorator only takes affect when a Slicerator object is passed. >>> single_img = images[0] >>> red_img = red_channel(single_img) # normal behavior Pipeline functions can take more than one slicerator. >>> @pipeline(ancestor_count=2) ... def sum_offset(img1, img2, offset): ... return img1 + img2 + offset """ def wrapper(f): return _pipeline(f, **kwargs) if func is None: return wrapper else: return wrapper(func)
def _pipeline(func_or_class, **kwargs): try: is_class = issubclass(func_or_class, Pipeline) except TypeError: is_class = False if is_class: return _pipeline_fromclass(func_or_class, **kwargs) else: return _pipeline_fromfunc(func_or_class, **kwargs) def _pipeline_fromclass(cls, retain_doc=False, ancestor_count=1): """Actual `pipeline` implementation Parameters ---------- func : class Class for lazy evaluation retain_doc : bool If True, don't modify `func`'s doc string to say that it has been made lazy ancestor_count : int or 'all', optional Number of inputs to the pipeline. Defaults to 1. Returns ------- Pipeline Lazy function evaluation :py:class:`Pipeline` for `func`. """ if ancestor_count == 'all': # subtract 1 for `self` ancestor_count = len(inspect.getfullargspec(cls).args) - 1 @wraps(cls) def process(*args, **kwargs): ancestors = args[:ancestor_count] args = args[ancestor_count:] all_pipe = all(hasattr(a, '_slicerator_flag') or isinstance(a, _pipeline_types) for a in ancestors) if all_pipe: return cls(*(ancestors + args), **kwargs) else: # Fall back on normal behavior of func, interpreting input # as a single image. return cls(*(tuple([a] for a in ancestors) + args), **kwargs)[0] if not retain_doc: if process.__doc__ is None: process.__doc__ = '' process.__doc__ = ("This function has been made lazy. When passed\n" "a Slicerator, it will return a \n" "Pipeline of the results. When passed \n" "any other objects, its behavior is " "unchanged.\n\n") + process.__doc__ process.__name__ = cls.__name__ return process def _pipeline_fromfunc(func, retain_doc=False, ancestor_count=1): """Actual `pipeline` implementation Parameters ---------- func : callable Function for lazy evaluation retain_doc : bool If True, don't modify `func`'s doc string to say that it has been made lazy ancestor_count : int or 'all', optional Number of inputs to the pipeline. Defaults to 1. Returns ------- Pipeline Lazy function evaluation :py:class:`Pipeline` for `func`. """ if ancestor_count == 'all': ancestor_count = len(inspect.getfullargspec(func).args) @wraps(func) def process(*args, **kwargs): ancestors = args[:ancestor_count] args = args[ancestor_count:] all_pipe = all(hasattr(a, '_slicerator_flag') or isinstance(a, _pipeline_types) for a in ancestors) if all_pipe: def proc_func(*x): return func(*(x + args), **kwargs) return Pipeline(proc_func, *ancestors) else: # Fall back on normal behavior of func, interpreting input # as a single image. return func(*(ancestors + args), **kwargs) if not retain_doc: if process.__doc__ is None: process.__doc__ = '' process.__doc__ = ("This function has been made lazy. When passed\n" "a Slicerator, it will return a \n" "Pipeline of the results. When passed \n" "any other objects, its behavior is " "unchanged.\n\n") + process.__doc__ process.__name__ = func.__name__ return process def propagate_attr(func): func._propagate_flag = True return func def index_attr(func): @wraps(func) def wrapper(obj, key, *args, **kwargs): indices = key_to_indices(key, len(obj))[0] if isinstance(indices, collections.abc.Iterable): return (func(obj, i, *args, **kwargs) for i in indices) else: return func(obj, indices, *args, **kwargs) wrapper._index_flag = True return wrapper class SliceableAttribute(object): """This class enables index-taking methods that are linked to a Slicerator object to remap their indices according to the Slicerator indices. It also enables fancy indexing, exactly like the Slicerator itself. The new attribute supports both calling and indexing to give identical results.""" def __init__(self, slicerator, attribute): self._ancestor = slicerator._ancestor self._len = slicerator._len self._get = attribute self._indices = slicerator.indices # make an independent copy @property def indices(self): # Advancing indices won't affect this new copy of self._indices. indices, self._indices = itertools.tee(iter(self._indices)) return indices def _map_index(self, key): if key < -self._len or key >= self._len: raise IndexError("Key out of range") try: abs_key = self._indices[key] except TypeError: key = key if key >= 0 else self._len + key for _, i in zip(range(key + 1), self.indices): abs_key = i return abs_key def __iter__(self): return (self._get(i) for i in self.indices) def __len__(self): return self._len def __call__(self, key, *args, **kwargs): if not (isinstance(key, slice) or isinstance(key, collections.abc.Iterable)): return self._get(self._map_index(key), *args, **kwargs) else: rel_indices, new_length = key_to_indices(key, len(self)) return (self[k] for k in rel_indices) def __getitem__(self, key): return self(key) # Based on https://github.com/soft-matter/slicerator # Original copyright and license information: # # Copyright (c) 2015, Daniel B. Allan # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # * Neither the name of the matplotlib project nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE.