"""A class to represent sets in Python.
This class implements sets as dictionaries whose values are ignored.
The usual operations (union, intersection, deletion, etc.) are
provided as both methods and operators.  The only unusual feature of
this class is that once a set's hash code has been calculated (for
example, once it has been used as a dictionary key, or as an element
in another set), that set 'freezes', and becomes immutable.  See
PEP-0218 for a full discussion.
"""

# Version-control keyword placeholders (RCS/CVS style); presumably
# filled in by keyword expansion on checkout -- left literal otherwise.
__version__ = '$Revision$'
__author__ = '$Author$'
__date__ = '$Date$'

from copy import deepcopy

class Set:
    """A set of hashable values, stored as the keys of a dictionary.

    Once a set's hash code has been computed (explicitly with hash(),
    or implicitly by using the set as a dict key or as an element of
    another Set) the set is frozen, and every mutating method raises
    ValueError from then on.
    """

    # Displayed when operation forbidden because set has been frozen
    _Frozen_Msg = "Set is frozen: %s not permitted"

    #----------------------------------------
    def __init__(self, seq=None, sort_repr=0):

        """Construct a set, optionally initializing it with elements
        drawn from a sequence.  If 'sort_repr' is true, the set's
        elements are displayed in sorted order.  This slows down
        conversion, but simplifies comparison during testing.  The
        'hashcode' element is given a non-None value the first time
        the set's hashcode is calculated; the set is frozen
        thereafter."""

        self.elements = {}
        self.sort_repr = sort_repr
        if seq is not None:
            for x in seq:
                self.elements[x] = None
        self.hashcode = None

    #----------------------------------------
    def __str__(self):
        """Convert set to string, e.g. 'Set([0, 1])'."""
        content = list(self.elements.keys())
        if self.sort_repr:
            # Deterministic display order, convenient for tests.
            content.sort()
        return 'Set(' + repr(content) + ')'

    #----------------------------------------
    # '__repr__' returns the same thing as '__str__'
    __repr__ = __str__

    #----------------------------------------
    def __len__(self):
        """Return number of elements in set."""
        return len(self.elements)

    #----------------------------------------
    def __contains__(self, item):
        """Test presence of value in set."""
        return item in self.elements

    #----------------------------------------
    def __iter__(self):
        """Return an iterator over the set's elements (the keys of the
        underlying dictionary)."""
        return iter(self.elements)

    #----------------------------------------
    def __eq__(self, other):
        """Return true if 'other' is a Set with exactly the same
        elements.  Comparing with a non-Set yields false instead of
        raising, so a frozen set can share a dictionary (or another
        Set) with keys of other types whose hash codes collide."""
        if not isinstance(other, Set):
            return False
        return self.elements == other.elements

    #----------------------------------------
    def __ne__(self, other):
        """Return the negation of __eq__."""
        return not self.__eq__(other)

    #----------------------------------------
    def __cmp__(self, other):
        """Order one set relative to another (consulted by Python 2
        only; equality is handled by __eq__/__ne__).  Sets may only be
        ordered against sets; ordering is that of the underlying
        dictionaries.  Raises ValueError if 'other' is not a Set."""
        if not isinstance(other, Set):
            raise ValueError("Sets can only be compared to sets")
        return cmp(self.elements, other.elements)

    #----------------------------------------
    def __hash__(self):

        """Calculate hash code for set by xor'ing hash codes of set
        elements.  This algorithm ensures that the hash code does not
        depend on the order in which elements are added to the set.
        Computing the hash code freezes the set."""

        # If set already has hashcode, the set has been frozen, so
        # code is still valid.
        if self.hashcode is not None:
            return self.hashcode

        # Accumulate in a local so the set is only marked frozen once
        # the hash code has been fully computed.
        code = 0
        for elt in self.elements:
            code ^= hash(elt)
        self.hashcode = code
        return code

    #----------------------------------------
    def is_frozen(self):

        """Report whether set is frozen or not.  A frozen set is one
        whose hash code has already been calculated.  Frozen sets may
        not be mutated, but unfrozen sets can be."""

        return self.hashcode is not None

    #----------------------------------------
    def __copy__(self):
        """Return a shallow, unfrozen copy of the set, keeping its
        'sort_repr' display setting."""
        result = Set(sort_repr=self.sort_repr)
        result.elements = self.elements.copy()
        return result

    #----------------------------------------
    # Define 'copy' method as readable alias for '__copy__'.
    copy = __copy__

    #----------------------------------------
    def __deepcopy__(self, memo):
        """Return a deep, unfrozen copy of the set, keeping its
        'sort_repr' display setting.  'memo' is the copy module's
        memo dictionary."""
        result          = Set(sort_repr=self.sort_repr)
        # Register before recursing so self-references resolve.
        memo[id(self)]  = result
        result.elements = deepcopy(self.elements, memo)
        return result

    #----------------------------------------
    def clear(self):
        """Remove all elements of unfrozen set."""
        self._check_mutable("clearing")
        self.elements.clear()

    #----------------------------------------
    def union_update(self, other):
        """Update set with union of its own elements and the elements
        in another set.  Returns self."""

        self._binary_sanity_check(other, "updating union")
        self.elements.update(other.elements)
        return self

    #----------------------------------------
    def union(self, other):
        """Create new set whose elements are the union of this set's
        and another's."""

        self._binary_sanity_check(other)
        result = self.__copy__()
        result.elements.update(other.elements)
        return result

    #----------------------------------------
    def intersect_update(self, other):
        """Update set with intersection of its own elements and the
        elements in another set.  Returns self."""

        self._binary_sanity_check(other, "updating intersection")
        new_elements = {}
        for elt in self.elements:
            if elt in other.elements:
                new_elements[elt] = None
        self.elements = new_elements
        return self

    #----------------------------------------
    def intersect(self, other):
        """Create new set whose elements are the intersection of this
        set's and another's."""

        self._binary_sanity_check(other)
        # Iterate over the smaller set, probing the larger one.
        if len(self) <= len(other):
            little, big = self, other
        else:
            little, big = other, self
        result = Set(sort_repr=self.sort_repr)
        for elt in little.elements:
            if elt in big.elements:
                result.elements[elt] = None
        return result

    #----------------------------------------
    def sym_difference_update(self, other):
        """Update set with symmetric difference of its own elements
        and the elements in another set.  A value 'x' is in the result
        if it was originally present in one or the other set, but not
        in both.  Returns self."""

        self._binary_sanity_check(other, "updating symmetric difference")
        self.elements = self._raw_sym_difference(self.elements, other.elements)
        return self

    #----------------------------------------
    def sym_difference(self, other):
        """Create new set with symmetric difference of this set's own
        elements and the elements in another set.  A value 'x' is in
        the result if it was originally present in one or the other
        set, but not in both."""

        self._binary_sanity_check(other)
        result = Set(sort_repr=self.sort_repr)
        result.elements = self._raw_sym_difference(self.elements, other.elements)
        return result

    #----------------------------------------
    def difference_update(self, other):
        """Remove all elements of another set from this set.  Returns
        self."""

        self._binary_sanity_check(other, "updating difference")
        new_elements = {}
        for elt in self.elements:
            if elt not in other.elements:
                new_elements[elt] = None
        self.elements = new_elements
        return self

    #----------------------------------------
    def difference(self, other):
        """Create new set containing elements of this set that are not
        present in another set."""

        self._binary_sanity_check(other)
        result = Set(sort_repr=self.sort_repr)
        for elt in self.elements:
            if elt not in other.elements:
                result.elements[elt] = None
        return result

    #----------------------------------------
    # Arithmetic forms of operations.  Union, intersection and
    # symmetric difference are commutative, so their reflected forms
    # can reuse the forward methods; subtraction cannot (see __rsub__).
    __or__      = union
    __ror__     = union
    __ior__     = union_update
    __and__     = intersect
    __rand__    = intersect
    __iand__    = intersect_update
    __xor__     = sym_difference
    __rxor__    = sym_difference
    __ixor__    = sym_difference_update
    __sub__     = difference
    __isub__    = difference_update

    #----------------------------------------
    def __rsub__(self, other):
        """Reflected subtraction: compute 'other - self'.  Raises
        ValueError if 'other' is not a Set."""
        self._binary_sanity_check(other)
        return other.difference(self)

    #----------------------------------------
    def add(self, item):
        """Add an item to a set.  This has no effect if the item is
        already present."""

        self._check_mutable("adding an element")
        self.elements[item] = None

    #----------------------------------------
    def update(self, iterable):
        """Add all values from an iterable (such as a tuple, list,
        or file) to this set."""

        self._check_mutable("adding an element")
        for item in iterable:
            self.elements[item] = None

    #----------------------------------------
    def remove(self, item):
        """Remove an element from a set if it is present, or raise a
        LookupError if it is not."""

        self._check_mutable("removing an element")
        try:
            del self.elements[item]
        except KeyError:
            raise LookupError(repr(item))

    #----------------------------------------
    def discard(self, item):
        """Remove an element from a set if it is present, or do
        nothing if it is not."""

        self._check_mutable("removing an element")
        try:
            del self.elements[item]
        except KeyError:
            pass

    #----------------------------------------
    def popitem(self):
        """Remove and return an arbitrary set element.  Raises
        LookupError if the set is empty, and ValueError if the set is
        frozen (like every other mutating method)."""

        self._check_mutable("removing an element")
        try:
            key, _ = self.elements.popitem()
        except KeyError:
            raise LookupError("set is empty")
        return key

    #----------------------------------------
    def is_subset_of(self, other):
        """Report whether other set contains this set."""
        if not isinstance(other, Set):
            raise ValueError("Subset tests only permitted between sets")
        for element in self.elements:
            if element not in other.elements:
                return 0
        return 1

    #----------------------------------------
    def contains_all_of(self, other):
        """Report whether this set contains every element of another
        set (i.e. whether the other set is a subset of this one)."""
        if not isinstance(other, Set):
            raise ValueError("Subset tests only permitted between sets")
        for element in other.elements:
            if element not in self.elements:
                return 0
        return 1

    #----------------------------------------
    # Raise ValueError naming 'operation' if this set has been frozen.
    def _check_mutable(self, operation):
        if self.hashcode is not None:
            raise ValueError(Set._Frozen_Msg % operation)

    #----------------------------------------
    # Check that the other argument to a binary operation is also a
    # set, and that this set is still mutable (if appropriate),
    # raising a ValueError if either condition is not met.
    def _binary_sanity_check(self, other, updating_op=''):
        if updating_op:
            self._check_mutable(updating_op)
        if not isinstance(other, Set):
            raise ValueError("Binary operation only permitted between sets")

    #----------------------------------------
    # Calculate the symmetric difference between the keys in two
    # dictionaries with don't-care values.
    def _raw_sym_difference(self, left, right):
        result = {}
        for elt in left:
            if elt not in right:
                result[elt] = None
        for elt in right:
            if elt not in left:
                result[elt] = None
        return result

#----------------------------------------------------------------------
# Rudimentary self-tests
#----------------------------------------------------------------------

if __name__ == "__main__":

    def _check(condition, message):
        """Raise AssertionError with 'message' unless 'condition' is
        true.  Used instead of 'assert' so the self-tests still run
        under 'python -O', which strips assert statements."""
        if not condition:
            raise AssertionError(message)

    # Empty set
    red = Set()
    _check(repr(red) == "Set([])", "Empty set: %s" % repr(red))

    # Unit set
    green = Set((0,), 1)
    _check(repr(green) == "Set([0])", "Unit set: %s" % repr(green))

    # 3-element set
    blue = Set([0, 1, 2], 1)
    _check(repr(blue) == "Set([0, 1, 2])", "3-element set: %s" % repr(blue))

    # 2-element set with other values
    black = Set([0, 5], 1)
    _check(repr(black) == "Set([0, 5])", "2-element set: %s" % repr(black))

    # All elements from all sets
    white = Set([0, 1, 2, 5], 1)
    _check(repr(white) == "Set([0, 1, 2, 5])", "4-element set: %s" % repr(white))

    # Add element to empty set
    red.add(9)
    _check(repr(red) == "Set([9])", "Add to empty set: %s" % repr(red))

    # Remove element from unit set
    red.remove(9)
    _check(repr(red) == "Set([])", "Remove from unit set: %s" % repr(red))

    # Remove element from empty set: must raise LookupError
    try:
        red.remove(0)
    except LookupError:
        pass
    else:
        raise AssertionError("Remove element from empty set: %s" % repr(red))

    # Length
    _check(len(red) == 0,   "Length of empty set")
    _check(len(green) == 1, "Length of unit set")
    _check(len(blue) == 3,  "Length of 3-element set")

    # Compare
    _check(green == Set([0]), "Equality failed")
    _check(green != Set([1]), "Inequality failed")

    # Union
    _check(blue  | red   == blue,  "Union non-empty with empty")
    _check(red   | blue  == blue,  "Union empty with non-empty")
    _check(green | blue  == blue,  "Union non-empty with non-empty")
    _check(blue  | black == white, "Enclosing union")

    # Intersection
    _check(blue  & red   == red,   "Intersect non-empty with empty")
    _check(red   & blue  == red,   "Intersect empty with non-empty")
    _check(green & blue  == green, "Intersect non-empty with non-empty")
    _check(blue  & black == green, "Enclosing intersection")

    # Symmetric difference
    _check(red ^ green == green,        "Empty symdiff non-empty")
    _check(green ^ blue == Set([1, 2]), "Non-empty symdiff")
    _check(white ^ white == red,        "Self symdiff")

    # Difference
    _check(red - green == red,           "Empty - non-empty")
    _check(blue - red == blue,           "Non-empty - empty")
    _check(white - black == Set([1, 2]), "Non-empty - non-empty")

    # In-place union
    orange = Set([])
    orange |= Set([1])
    _check(orange == Set([1]), "In-place union")

    # In-place intersection
    orange = Set([1, 2])
    orange &= Set([2])
    _check(orange == Set([2]), "In-place intersection")

    # In-place difference
    orange = Set([1, 2, 3])
    orange -= Set([2, 4])
    _check(orange == Set([1, 3]), "In-place difference")

    # In-place symmetric difference
    orange = Set([1, 2, 3])
    orange ^= Set([3, 4])
    _check(orange == Set([1, 2, 4]), "In-place symmetric difference")

    print("All tests passed")
