#!/usr/bin/env python3.0

"""Abstract Base Classes (ABCs) experiment.  See PEP 3119.

Note: this code depends on an unsubmitted patch:
http://python.org/sf/1708353.
"""

__author__ = "Guido van Rossum <guido@python.org>"

import functools
import inspect
import itertools
import sys
import types


### ABC SUPPORT FRAMEWORK ###


def abstractmethod(funcobj):
    """Mark *funcobj* as abstract and return it unchanged.

    Requires that the metaclass is ABCMeta or derived from it.  A
    class whose metaclass derives from ABCMeta cannot be instantiated
    until every abstract method has been overridden; the abstract
    implementations remain callable through the usual 'super'
    mechanisms.

    Usage:

        class C(metaclass=ABCMeta):
            @abstractmethod
            def my_abstract_method(self, ...):
                ...
    """
    funcobj.__isabstractmethod__ = True
    return funcobj


class _Abstract(object):

    """Helper class inserted into the bases by ABCMeta (using _fix_bases()).

    You should never need to explicitly subclass this class.
    """

    def __new__(cls, *args, **kwds):
        am = cls.__dict__.get("__abstractmethods__")
        if am:
            raise TypeError("can't instantiate abstract class %s "
                            "with abstract methods %s" %
                            (cls.__name__, ", ".join(sorted(am))))
        return super(_Abstract, cls).__new__(cls, *args, **kwds)


def _fix_bases(bases):
    """Helper that splices _Abstract into *bases* when it is missing."""
    if any(issubclass(base, _Abstract) for base in bases):
        # _Abstract is already a base (possibly indirectly).
        return bases
    if object in bases:
        # Swap object out for _Abstract (which itself derives from object).
        return tuple(_Abstract if base is object else base
                     for base in bases)
    # Otherwise just append _Abstract at the end.
    return bases + (_Abstract,)


class ABCMeta(type):

    """Metaclass for defining Abstract Base Classes (ABCs).

    Use this metaclass to create an ABC.  An ABC can be subclassed
    directly, and then acts as a mix-in class.  You can also register
    unrelated concrete classes (even built-in classes) and unrelated
    ABCs as 'virtual subclasses' -- these and their descendants will
    be considered subclasses of the registering ABC by the built-in
    issubclass() function, but the registering ABC won't show up in
    their MRO (Method Resolution Order) nor will method
    implementations defined by the registering ABC be callable (not
    even via super()).

    """

    # A global counter that is incremented each time a class is
    # registered as a virtual subclass of anything.  It forces the
    # negative cache to be cleared before its next use.
    __invalidation_counter = 0

    def __new__(mcls, name, bases, namespace):
        """Create the class: splice in _Abstract, collect the abstract
        method names, and set up the registry and caches used by
        __subclasscheck__()."""
        bases = _fix_bases(bases)
        cls = super(ABCMeta, mcls).__new__(mcls, name, bases, namespace)
        # Abstract names defined directly in this class body.
        abstracts = {name
                     for name, value in namespace.items()
                     if getattr(value, "__isabstractmethod__", False)}
        # Inherited abstract names that this class did not override
        # with a concrete implementation stay abstract.
        for base in bases:
            for name in getattr(base, "__abstractmethods__", set()):
                value = getattr(cls, name, None)
                if getattr(value, "__isabstractmethod__", False):
                    abstracts.add(name)
        cls.__abstractmethods__ = abstracts
        # Set up inheritance registry and per-class caches.
        cls.__abc_registry__ = set()
        cls.__abc_cache__ = set()
        cls.__abc_negative_cache__ = set()
        cls.__abc_negative_cache_version__ = ABCMeta.__invalidation_counter
        return cls

    def register(cls, subclass):
        """Register a virtual subclass of an ABC."""
        # Bug fix: validate the argument being registered; cls is
        # always a class here because this is a metaclass method.
        if not isinstance(subclass, type):
            raise TypeError("Can only register classes")
        if issubclass(subclass, cls):
            return  # Already a subclass
        # Subtle: test for cycles *after* testing for "already a subclass";
        # this means we allow X.register(X) and interpret it as a no-op.
        if issubclass(cls, subclass):
            # This would create a cycle, which is bad for the algorithm below
            raise RuntimeError("Refusing to create an inheritance cycle")
        cls.__abc_registry__.add(subclass)
        ABCMeta.__invalidation_counter += 1  # Invalidate negative cache

    def _dump_registry(cls, file=None):
        """Debug helper to print the ABC registry."""
        if file is None:
            file = sys.stdout
        print("Class: %s.%s" % (cls.__module__, cls.__name__), file=file)
        print("Inv.counter: %s" % ABCMeta.__invalidation_counter, file=file)
        for name in sorted(cls.__dict__.keys()):
            if name.startswith("__abc_"):
                value = getattr(cls, name)
                print("%s: %r" % (name, value), file=file)

    def __instancecheck__(cls, instance):
        """Override for isinstance(instance, cls).

        Checks both instance.__class__ and type(instance); these can
        differ (e.g. when __class__ is overridden), and either match
        counts.
        """
        return any(cls.__subclasscheck__(c)
                   for c in {instance.__class__, type(instance)})

    def __subclasscheck__(cls, subclass):
        """Override for issubclass(subclass, cls)."""
        # Check positive cache
        if subclass in cls.__abc_cache__:
            return True
        # Check negative cache; may have to invalidate
        if cls.__abc_negative_cache_version__ < ABCMeta.__invalidation_counter:
            # Registrations happened since last use; invalidate
            cls.__abc_negative_cache_version__ = ABCMeta.__invalidation_counter
            cls.__abc_negative_cache__ = set()
        elif subclass in cls.__abc_negative_cache__:
            return False
        # Check if it's a direct subclass
        if cls in subclass.__mro__:
            cls.__abc_cache__.add(subclass)
            return True
        # Check if it's a subclass of a registered class (recursive)
        for rcls in cls.__abc_registry__:
            if issubclass(subclass, rcls):
                # Bug fix: record the hit in the positive *cache*.
                # Adding it to the registry (as before) silently grew
                # the registration state without bumping the
                # invalidation counter.
                cls.__abc_cache__.add(subclass)
                return True
        # Check if it's a subclass of a subclass (recursive)
        for scls in cls.__subclasses__():
            if issubclass(subclass, scls):
                # Bug fix: same as above -- cache, don't register.
                cls.__abc_cache__.add(subclass)
                return True
        # No dice; update negative cache
        cls.__abc_negative_cache__.add(subclass)
        return False


### ONE TRICK PONIES ###


class Hashable(metaclass=ABCMeta):

    """ABC with a single abstract method, __hash__()."""

    @abstractmethod
    def __hash__(self):
        # Degenerate default; real implementations override this.
        return 0


class Iterable(metaclass=ABCMeta):

    """ABC with a single abstract method, __iter__()."""

    @abstractmethod
    def __iter__(self):
        # Degenerate default: iterate over nothing at all.
        return _EmptyIterator()


class Iterator(Iterable):

    """An iterator supplies __next__() in addition to __iter__()."""

    @abstractmethod
    def __next__(self):
        # Degenerate default: behave as an exhausted iterator.
        raise StopIteration

    def __iter__(self):
        # Concrete!  An iterator should always be its own iterator.
        return self


class _EmptyIterator(Iterator):

    """Implementation detail used by Iterable.__iter__()."""

    def __next__(self):
        # Delegate to Iterator.__next__(), which raises StopIteration.
        return super(_EmptyIterator, self).__next__()


class Sized(metaclass=ABCMeta):

    """ABC with a single abstract method, __len__()."""

    @abstractmethod
    def __len__(self):
        # Degenerate default; concrete subclasses must override.
        return 0


class Container(metaclass=ABCMeta):

    """ABC with a single abstract method, __contains__()."""

    @abstractmethod
    def __contains__(self, elem):
        # Degenerate default: nothing is contained.
        return False

class Searchable(Container):

    """A container whose __contains__ accepts sequences too.

    Adds no methods of its own; it exists purely as a marker that the
    containment test also accepts a sequence of candidates.
    """

    # XXX This is an experiment.  Is it worth distinguishing?
    # Perhaps later, when we have type annotations so you can write
    # Container[T], we can do this:
    #
    # class Container(metaclass=ABCMeta):
    #     def __contains__(self, val: T) -> bool: ...
    #
    # class Searchable(Container):
    #     def __contains__(self, val: T | Sequence[T]) -> bool: ...


### SETS ###


class Set(Sized, Iterable, Container):

    """A plain set is a finite, iterable container.

    This class provides concrete generic implementations of all
    methods except for __len__, __iter__ and __contains__.

    To override the comparisons (presumably for speed, as the
    semantics are fixed), all you have to do is redefine __le__ and
    then the other operations will automatically follow suit.
    """

    def __le__(self, other):
        """Subset test: True iff every element of self is in other."""
        if not isinstance(other, Set):
            return NotImplemented
        if len(self) > len(other):
            return False
        for elem in self:
            if elem not in other:
                return False
        return True

    def __lt__(self, other):
        """Proper subset: strictly smaller and a subset."""
        if not isinstance(other, Set):
            return NotImplemented
        return len(self) < len(other) and self.__le__(other)

    def __eq__(self, other):
        """Equal iff the sets have the same size and self <= other."""
        if not isinstance(other, Set):
            return NotImplemented
        return len(self) == len(other) and self.__le__(other)

    @classmethod
    def _from_iterable(cls, it):
        """Build the result of a binary operation from an iterable.

        Subclasses with an incompatible constructor signature can
        override this hook.
        """
        return frozenset(it)

    def __and__(self, other):
        """Intersection with any iterable."""
        if not isinstance(other, Iterable):
            return NotImplemented
        return self._from_iterable(value for value in other if value in self)

    def __or__(self, other):
        """Union with any iterable."""
        if not isinstance(other, Iterable):
            return NotImplemented
        return self._from_iterable(itertools.chain(self, other))

    def __sub__(self, other):
        """Difference: elements of self not in other."""
        if not isinstance(other, Set):
            if not isinstance(other, Iterable):
                return NotImplemented
            other = self._from_iterable(other)
        return self._from_iterable(value for value in self
                                   if value not in other)

    def __xor__(self, other):
        """Symmetric difference."""
        if not isinstance(other, Set):
            if not isinstance(other, Iterable):
                return NotImplemented
            other = self._from_iterable(other)
        return (self - other) | (other - self)

    def _hash(self):
        """The hash value must match __eq__.

        All sets ought to compare equal if they contain the same
        elements, regardless of how they are implemented, and
        regardless of the order of the elements; so there's not much
        freedom for __eq__ or __hash__.  We match the algorithm used
        by the built-in frozenset type.
        """
        # Bug fix: sys.maxint no longer exists in Python 3;
        # sys.maxsize is the word-sized equivalent.
        MAX = sys.maxsize
        MASK = 2 * MAX + 1
        n = len(self)
        h = 1927868237 * (n + 1)
        h &= MASK
        for x in self:
            hx = hash(x)
            h ^= (hx ^ (hx << 16) ^ 89869747)  * 3644798167
            h &= MASK
        h = h * 69069 + 907133923
        h &= MASK
        if h > MAX:
            # Fold back into the signed range.
            h -= MASK + 1
        if h == -1:
            # -1 is reserved as an error value by CPython's hashing.
            h = 590923713
        return h


# XXX Should this derive from Set instead of from ComposableSet?
class MutableSet(Set):

    """Set ABC with mutation; requires concrete add() and discard()."""

    @abstractmethod
    def add(self, value):
        """Return True if it was added, False if already there."""
        raise NotImplementedError

    @abstractmethod
    def discard(self, value):
        """Return True if it was deleted, False if not there."""
        raise NotImplementedError

    def pop(self):
        """Return the popped value.  Raise KeyError if empty."""
        it = iter(self)
        try:
            value = next(it)
        except StopIteration:
            raise KeyError
        self.discard(value)
        return value

    def toggle(self, value):
        """Return True if it was added, False if deleted."""
        # XXX This implementation is not thread-safe
        if value in self:
            self.discard(value)
            return False
        else:
            self.add(value)
            return True

    def clear(self):
        """This is slow (creates N new iterators!) but effective."""
        try:
            while True:
                self.pop()
        except KeyError:
            pass

    def __ior__(self, it: Iterable):
        """In-place union: add every value from the iterable."""
        for value in it:
            self.add(value)
        return self

    def __iand__(self, c: Container):
        """In-place intersection: drop values not in the container."""
        for value in self:
            if value not in c:
                self.discard(value)
        return self

    def __ixor__(self, it: Iterable):
        """In-place symmetric difference via toggle()."""
        # This calls toggle(), so if that is overridden, we call the override
        for value in it:
            # Bug fix: toggle each individual value, not the iterable
            # itself (the original passed `it`).
            self.toggle(value)
        return self

    def __isub__(self, it: Iterable):
        """In-place difference: discard every value from the iterable."""
        for value in it:
            self.discard(value)
        return self


### MAPPINGS ###

# XXX Get rid of _BasicMapping and view types

class _BasicMapping(Container, Iterable):

    """Minimal read-only mapping: abstract __getitem__ plus helpers."""

    @abstractmethod
    def __getitem__(self, key):
        # Degenerate default: no key is ever present.
        raise KeyError

    def get(self, key, default=None):
        """Return self[key], or *default* when the key is absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        """Membership is defined as indexing succeeding."""
        try:
            self[key]
        except KeyError:
            return False
        else:
            return True

    def keys(self):
        """Return a view over the keys."""
        return KeysView(self)

    def items(self):
        """Return a view over the (key, value) pairs."""
        return ItemsView(self)

    def values(self):
        """Return a view over the values."""
        return ValuesView(self)


class _MappingView:

    def __new__(cls, *args):
        return object.__new__(cls)

    def __init__(self, mapping):
        self._mapping = mapping


class KeysView(_MappingView, Container):

    """Iterable, containment-testable view over a mapping's keys."""

    def __iter__(self):
        # Iterating a mapping yields its keys.
        for k in self._mapping:
            yield k

    def __contains__(self, key):
        # Delegate straight to the mapping's containment test.
        return key in self._mapping


class ItemsView(_MappingView, Container):

    """Iterable, containment-testable view over (key, value) pairs."""

    def __iter__(self):
        for key in self._mapping:
            yield key, self._mapping[key]

    def __contains__(self, item):
        """True iff *item* is a (key, value) pair present in the mapping."""
        try:
            key, value = item
        except (TypeError, ValueError):
            # Bug fix: the bare `except:` also swallowed things like
            # KeyboardInterrupt; only unpacking failures mean "not a
            # 2-tuple, hence not contained".
            return False
        try:
            val = self._mapping[key]
        except KeyError:
            return False
        return value == val


class ValuesView(_MappingView):

    """Iterable view over a mapping's values.

    Note: does not derive from Container and implements no
    __contains__ at this stage.
    """

    def __iter__(self):
        for k in self._mapping:
            yield self._mapping[k]


class Mapping(_BasicMapping, Sized):

    """Sized read-only mapping with views and value-based equality."""

    def keys(self):
        """Return a view over the keys."""
        return KeysView(self)

    def items(self):
        """Return a view over the (key, value) pairs."""
        return ItemsView(self)

    def values(self):
        """Return a view over the values."""
        return ValuesView(self)

    def __eq__(self, other):
        """Mappings are equal when they hold the same key/value pairs."""
        if not isinstance(other, Mapping):
            return NotImplemented
        if len(other) != len(self):
            return False
        # XXX Or: for key, value1 in self.items(): ?
        for key in self:
            mine = self[key]
            try:
                theirs = other[key]
            except KeyError:
                return False
            if mine != theirs:
                return False
        return True


class _MappingView(_MappingView, Sized):

    """Augment the basic mapping view with Sized support."""

    def __len__(self):
        # A view is exactly as long as the mapping it wraps.
        return len(self._mapping)


class KeysView(_MappingView, KeysView, Set):

    """Keys view, upgraded to be Sized (via _MappingView) and a Set."""

    pass


class ItemsView(_MappingView, ItemsView, Set):

    """Items view, upgraded to be Sized (via _MappingView) and a Set."""

    pass


class ValuesView(_MappingView, ValuesView):

    """Values view, upgraded to be Sized, with multiset equality."""

    def __eq__(self, other):
        """Equal to any sized iterable with the same values, counting
        multiplicity but ignoring order."""
        if not (isinstance(other, Sized) and isinstance(other, Iterable)):
            return NotImplemented
        if len(self) != len(other):
            return False
        # XXX This is slow. Sometimes this could be optimized, but these
        # are the semantics: we can't depend on the values to be hashable
        # or comparable.
        leftovers = list(other)
        for value in self:
            for i, candidate in enumerate(leftovers):
                if value == candidate:
                    # Consume the first match so duplicates count.
                    del leftovers[i]
                    break
            else:
                return False
        assert not leftovers  # self must have mutated somehow
        return True

    def __contains__(self, value):
        """Linear-scan membership test (values need not be hashable)."""
        return any(elem == value for elem in self)


class MutableMapping(Mapping):

    """Mapping ABC with mutation; needs __setitem__ and __delitem__."""

    @abstractmethod
    def __setitem__(self, key, value):
        # Bug fix: the abstract signature was missing the *value*
        # argument that every `self[key] = value` call (e.g. in
        # update() below) supplies.
        raise NotImplementedError

    @abstractmethod
    def __delitem__(self, key):
        raise NotImplementedError

    # Private sentinel distinguishing "no default given" from None.
    __marker = object()

    def pop(self, key, default=__marker):
        """Remove *key* and return its value.

        Without *default*, a missing key raises KeyError; otherwise
        *default* is returned for a missing key.
        """
        try:
            value = self[key]
        except KeyError:
            if default is self.__marker:
                raise
            return default
        else:
            del self[key]
            return value

    def popitem(self):
        """Remove and return an arbitrary (key, value) pair.

        Raises KeyError when the mapping is empty.
        """
        try:
            key = next(iter(self))
        except StopIteration:
            raise KeyError
        value = self[key]
        del self[key]
        return key, value

    def clear(self):
        """Remove all items, one popitem() at a time."""
        try:
            while True:
                self.popitem()
        except KeyError:
            pass

    def update(self, other=(), **kwds):
        """Update from a Mapping, an object with keys(), or an
        iterable of (key, value) pairs; keyword args win last."""
        if isinstance(other, Mapping):
            for key in other:
                self[key] = other[key]
        elif hasattr(other, "keys"):
            for key in other.keys():
                self[key] = other[key]
        else:
            for key, value in other:
                self[key] = value
        for key, value in kwds.items():
            self[key] = value


### SEQUENCES ###


def _index(i):
    # Internal helper to raise TypeError for non-integer(-castable) values
    if not isinstance(i, int):
        if not hasattr(i, "__index__"):
            raise TypeError
        i = i.__index__()
        if not isinstance(i, int):
            raise TypeError
    return i


def _slice(slc, size):
    """Internal helper: normalize a slice against a sequence length.

    Returns (start, stop, step) as plain ints, applying the usual
    negative-index adjustment and missing-bound defaults.

    XXX A problem this shares with Python 2: a[n-1:-1:-1] (where n is
    len(a)) returns an empty slice because the stop value is
    normalized to n-1.
    """
    def endpoint(bound):
        # Resolve one endpoint: None passes through untouched,
        # negative values count from the end of the sequence.
        if bound is None:
            return None
        bound = _index(bound)
        if bound < 0:
            bound += size
        return bound

    start = endpoint(slc.start)
    stop = endpoint(slc.stop)
    step = 1 if slc.step is None else _index(slc.step)

    if step == 0:
        raise ValueError
    if step > 0:
        if start is None:
            start = 0
        if stop is None:
            stop = size
    else:
        if start is None:
            start = size - 1
        if stop is None:
            stop = -1

    return start, stop, step



class Sequence(Sized, Iterable, Container):

    """A minimal sequence.

    I'm not bothering with an unsized version; I don't see a use case.

    Concrete subclasses must override __new__ or __init__, __getitem__,
    and __len__; they might want to override __add__ and __mul__.  The
    constructor signature is expected to support a single argument
    giving an iterable providing the elements.
    """

    @abstractmethod
    def __getitem__(self, index):
        """Return the element at *index*, or a new sequence for a slice.

        The abstract default handles slicing generically by building a
        new instance of the same class; concrete overrides should
        delegate slice arguments back here via super().
        """
        if isinstance(index, slice):
            start, stop, step = _slice(index, len(self))
            return self.__class__(self[i] for i in range(start, stop, step))
        else:
            # Validate the index type; concrete subclasses supply the
            # actual element access.
            index = _index(index)
            raise IndexError

    @abstractmethod
    def __len__(self):
        """Return the number of elements; must be overridden."""
        return 0

    def __iter__(self):
        """Iterate by indexing from 0 up to len(self)."""
        i = 0
        while i < len(self):
            yield self[i]
            i += 1

    def __contains__(self, value):
        """Linear-scan membership test using ==."""
        for val in self:
            if val == value:
                return True
        return False

    # XXX Do we want all or some of the following?

    def __reversed__(self):
        """Iterate from the last index down to 0."""
        i = len(self)
        while i > 0:
            i -= 1
            yield self[i]

    def index(self, value):
        """Return the first index whose element equals *value*.

        Raises ValueError when no element matches.
        """
        # XXX Should we add optional start/stop args?  Probably not.
        for i, elem in enumerate(self):
            if elem == value:
                return i
        raise ValueError

    def count(self, value):
        """Return the number of elements equal to *value*."""
        return sum(1 for elem in self if elem == value)

    def __add__(self, other):
        """Concatenate two sequences into a new instance of this class."""
        if not isinstance(other, Sequence):
            return NotImplemented
        return self.__class__(elem for seq in (self, other) for elem in seq)

    def __mul__(self, repeat):
        """Repeat the sequence *repeat* times (an int or __index__-able)."""
        # XXX Looks like we need an ABC to indicate integer-ness...
        if not isinstance(repeat, int) and not hasattr(repeat, "__index__"):
            return NotImplemented
        repeat = _index(repeat)
        return self.__class__(elem for i in range(repeat) for elem in self)

    # XXX Should we derive from PartiallyOrdered or TotallyOrdered?
    # That depends on the items.  What if the items aren't orderable?

    def __eq__(self, other):
        """Elementwise equality against another Sequence."""
        if not isinstance(other, Sequence):
            return NotImplemented
        if len(self) != len(other):
            return False
        for a, b in zip(self, other):
            if a == b:
                continue
            return False
        # Always True here: unequal lengths were already rejected above.
        return len(self) == len(other)

    def __lt__(self, other):
        """Lexicographic less-than against another Sequence."""
        if not isinstance(other, Sequence):
            return NotImplemented
        for a, b in zip(self, other):
            if a == b:
                continue
            # The first differing pair decides the comparison.
            return a < b
        return len(self) < len(other)

    def __le__(self, other):
        """Lexicographic less-or-equal against another Sequence."""
        if not isinstance(other, Sequence):
            return NotImplemented
        for a, b in zip(self, other):
            if a == b:
                continue
            # The first differing pair decides (strictly less there).
            return a < b
        return len(self) <= len(other)


class MutableSequence(Sequence):

    """Sequence ABC with mutation; needs __setitem__, __delitem__,
    and insert()."""

    @abstractmethod
    def __setitem__(self, i, value):
        raise NotImplementedError

    @abstractmethod
    def __delitem__(self, i):
        # Bug fix: the abstract signature had a spurious *value*
        # parameter; `del self[i]` (as used in pop()) passes only the
        # index.
        raise NotImplementedError

    @abstractmethod
    def insert(self, i, value):
        raise NotImplementedError

    def append(self, value):
        """Add *value* at the end (insert at len(self))."""
        self.insert(len(self), value)

    def reverse(self):
        """Reverse in place by swapping elements pairwise."""
        n = len(self)
        for i in range(n//2):
            j = n-i-1
            self[i], self[j] = self[j], self[i]

    def extend(self, it):
        """Append every value from the iterable."""
        for x in it:
            self.append(x)

    def pop(self, i=None):
        """Remove and return the element at *i* (default: the last)."""
        if i is None:
            i = len(self) - 1
        value = self[i]
        del self[i]
        return value

    def remove(self, value):
        """Delete the first element equal to *value*.

        Raises ValueError when no element matches.
        """
        for i in range(len(self)):
            if self[i] == value:
                del self[i]
                return
        raise ValueError



### PRE-DEFINED REGISTRATIONS ###

Hashable.register(int)
Hashable.register(float)
Hashable.register(complex)
# Bug fix: basestring does not exist in Python 3; str is the one
# built-in (immutable, hashable) string type.
Hashable.register(str)
Hashable.register(tuple)
Hashable.register(frozenset)
Hashable.register(type)

Set.register(frozenset)
MutableSet.register(set)

MutableMapping.register(dict)

Sequence.register(tuple)
Sequence.register(str)
Sequence.register(bytes)  # immutable in Python 3, so only a Sequence
MutableSequence.register(list)
# Bug fix: in Python 3 bytes is immutable; bytearray is the mutable
# byte sequence.
MutableSequence.register(bytearray)


### ADAPTERS ###

# This is just an example, not something to go into the stdlib


class AdaptToSequence(Sequence):

    """Adapter presenting any indexable (or iterable) as a Sequence."""

    def __new__(cls, adaptee):
        if not hasattr(adaptee, "__getitem__"):
            # Hack: materialize plain iterables so the generator-based
            # self.__class__(<generator>) calls in Sequence keep working.
            adaptee = list(adaptee)
        instance = Sequence.__new__(cls)
        instance.adaptee = adaptee
        return instance

    def __getitem__(self, index):
        """Delegate indexing; slices go through Sequence's generic path."""
        if isinstance(index, slice):
            return super(AdaptToSequence, self).__getitem__(index)
        return self.adaptee[_index(index)]

    def __len__(self):
        return len(self.adaptee)


class AdaptToMapping(Mapping):

    """Adapter presenting any mapping-like object as a Mapping."""

    def __new__(cls, adaptee):
        instance = Mapping.__new__(cls)
        instance.adaptee = adaptee
        return instance

    def __getitem__(self, index):
        # Delegate key lookup to the wrapped object.
        return self.adaptee[index]

    def __len__(self):
        return len(self.adaptee)

    def __iter__(self):
        return iter(self.adaptee)


class AdaptToSet(Set):

    """Adapter presenting any set-like object as a Set."""

    def __new__(cls, adaptee):
        instance = Set.__new__(cls)
        instance.adaptee = adaptee
        return instance

    def __contains__(self, elem):
        # Delegate containment to the wrapped object.
        return elem in self.adaptee

    def __iter__(self):
        return iter(self.adaptee)

    def __len__(self):
        return len(self.adaptee)


### OVERLOADING ###

# This is a modest alternative proposal to PEP 3124.  It uses
# issubclass() exclusively meaning that any issubclass() overloading
# automatically works.  If accepted it probably ought to go into a
# separate module (overloading.py?) as it has nothing to do directly
# with ABCs.  The code here is an evolution from my earlier attempt in
# sandbox/overload/overloading.py.


class overloadable:

    """An implementation of overloadable functions.

    Usage example:

    @overloadable
    def flatten(x):
        yield x

    @flatten.overload
    def _(it: Iterable):
        for x in it:
            yield x

    @flatten.overload
    def _(x: str):
        yield x

    """

    def __init__(self, default_func):
        """Decorator to declare a new overloadable function.

        *default_func* is called when no registered signature matches.
        """
        self.registry = {}   # maps tuple-of-types -> implementation
        self.cache = {}      # maps tuple-of-classes -> resolved funcs
        self.default_func = default_func

    def __get__(self, obj, cls=None):
        """Descriptor support so an instance can serve as a method."""
        if obj is None:
            return self
        # Bug fix: the Python 2 'new' module was never imported here
        # and no longer exists; types.MethodType is the Python 3
        # spelling of a bound method.
        return types.MethodType(self, obj)

    def overload(self, func):
        """Decorator to overload a function using its argument annotations."""
        self.register_func(self.extract_types(func), func)
        if func.__name__ == self.default_func.__name__:
            # Same name: assume it rebinds the defining name, so keep
            # returning the dispatcher itself.
            return self
        else:
            return func

    def extract_types(self, func):
        """Helper to extract argument annotations as a tuple of types.

        Unannotated arguments default to object.
        """
        args, varargs, varkw, defaults, kwonlyargs, kwdefaults, annotations = \
              inspect.getfullargspec(func)
        return tuple(annotations.get(arg, object) for arg in args)

    def register_func(self, types, func):
        """Helper to register an implementation."""
        self.registry[types] = func
        self.cache = {} # Clear the cache (later we might optimize this).

    def __call__(self, *args):
        """Call the overloaded function."""
        types = tuple(arg.__class__ for arg in args)
        funcs = self.cache.get(types)
        if funcs is None:
            self.cache[types] = funcs = list(self.find_funcs(types))
        return funcs[0](*args)

    def find_funcs(self, types):
        """Yield the appropriate overloaded functions, in order."""
        func = self.registry.get(types)
        if func is not None:
            # Easy case -- direct hit in registry.
            yield func
            return

        candidates = [cand
                      for cand in self.registry
                      if self.implies(types, cand)]

        if not candidates:
            # Easy case -- return the default function
            yield self.default_func
            return

        if len(candidates) == 1:
            # Easy case -- return this and the default function
            yield self.registry[candidates[0]]
            yield self.default_func
            return

        # Sort on a partial ordering!  Bug fix: Python 3's list.sort()
        # no longer accepts a cmp function, so adapt the cmp-style
        # comparator via functools.cmp_to_key.
        candidates.sort(key=functools.cmp_to_key(self.comparator))
        while candidates:
            cand = candidates.pop(0)
            if all(self.implies(cand, c) for c in candidates):
                yield self.registry[cand]
            else:
                # No single most-specific candidate remains.
                yield self.raise_ambiguity
                break
        else:
            yield self.default_func

    def comparator(self, xs, ys):
        """cmp-style comparison: negative when xs is more specific."""
        return self.implies(ys, xs) - self.implies(xs, ys)

    def implies(self, xs, ys):
        """True when signature xs is at least as specific as ys."""
        return len(xs) == len(ys) and all(issubclass(x, y)
                                          for x, y in zip(xs, ys))

    def raise_ambiguity(self, *args):
        # XXX Should be more specific
        raise TypeError("ambiguous signature of overloadable function")

    def raise_exhausted(self, *args):
        # XXX Should be more specific
        raise TypeError("no remaining candidates for overloadable function")
