# Test the hotbuf module.
#
# $Id$
#
#  Copyright (C) 2006   Martin Blais <blais@furius.ca>
#  Licensed to PSF under a Contributor Agreement.
#

from hotbuf import hotbuf, BoundaryError
from struct import Struct, pack
import unittest
from cStringIO import StringIO
from test import test_support

# Capacity (in bytes) of the buffers created by most of the tests below.
CAPACITY = 1024
# Sample payload written into buffers and read back for comparison.
MSG = 'Martin Blais was here scribble scribble.'


#------------------------------------------------------------------------
#
class HotbufTestCase(unittest.TestCase):
    """
    Unit tests for the individual operations of the hotbuf type.

    A hotbuf has a fixed capacity and an active window delimited by a
    position and a limit; len() and str() operate on that window only.
    """

    # Note: we don't use floats because comparisons will cause precision
    # errors due to the binary conversion.
    fmt = Struct('llci')

    def test_base( self ):
        # A capacity of 0 or less is rejected.
        self.assertRaises(ValueError, hotbuf, -1)
        self.assertRaises(ValueError, hotbuf, 0)
        b = hotbuf(CAPACITY)
        self.assertEquals(len(b), CAPACITY)
        self.assertEquals(b.capacity, CAPACITY)

        # Play with the position; len() is the size of the window
        # between the position and the limit.
        b.limit = 100
        self.assertEquals(b.position, 0)
        b.position = 10
        self.assertEquals(b.position, 10)
        self.assertEquals(len(b), 90)
        b.position = b.limit
        self.assertEquals(b.position, b.limit)
        def setposition( b, val ):
            b.position = val
        # The position may be neither negative nor beyond the limit.
        self.assertRaises(BoundaryError, setposition, b, -1)
        self.assertRaises(BoundaryError, setposition, b, b.limit + 1)
        self.assertRaises(BoundaryError, setposition, b, CAPACITY + 1)

        # Play with the limit.
        b.position = 10
        b.limit = 100
        self.assertEquals(b.limit, 100)
        b.limit = 110
        self.assertEquals(b.limit, 110)
        def setlimit( b, val ):
            b.limit = val
        self.assertRaises(BoundaryError, setlimit, b, CAPACITY + 1)
        # Moving the limit below the position drags the position along.
        b.limit = b.position - 1
        self.assertEquals(b.position, b.limit)

        # clear() resets the window to the whole buffer.
        b.clear()
        self.assertEquals((b.position, b.limit), (0, CAPACITY))

        # flip() makes the window run from 0 up to the old position.
        b.position = 42
        b.limit = 104
        b.flip()
        self.assertEquals((b.position, b.limit), (0, 42))

        # Play with length.
        self.assertEquals(len(b), 42)
        b.position = 10
        self.assertEquals(len(b), 32)

        # Play with advancing the position.
        self.assertEquals(b.position, 10)
        b.position += 32
        self.assertEquals(b.position, 42)

        # The limit is 42 here, so CAPACITY is out of bounds.
        self.assertRaises(BoundaryError, setposition, b, CAPACITY)

        # setlen(n) sets the limit to position + n.
        b.clear()
        b.setlen(12)
        self.assertEquals((b.position, b.limit), (0, 12))
        b.position += 3
        b.setlen(12)
        self.assertEquals((b.position, b.limit), (3, 15))

    def test_compact( self ):
        b = hotbuf(CAPACITY)

        # compact() moves the window to the start of the buffer; the new
        # position is just past the moved data and the limit becomes the
        # capacity.
        b.position = 100
        b.limit = 200
        b.compact()
        self.assertEquals((b.position, b.limit), (100, CAPACITY))

        # Compare the text that gets compacted.
        b.clear()
        b.position = 100
        b.putstr(MSG)
        b.limit = b.position
        b.position = 100
        b.compact()
        b.flip()
        self.assertEquals(str(b), MSG)

    def test_byte( self ):
        b = hotbuf(256)

        # Fill up the buffer with every possible byte value.
        for x in xrange(256):
            b.putbyte(x)

        # Test overflow: writing into a full buffer must fail.
        self.assertRaises(BoundaryError, b.putbyte, 42)

        # Read all data back from the buffer.
        b.flip()
        for x in xrange(256):
            self.assertEquals(b.getbyte(), x)

        # Test underflow: reading past the limit must fail.
        self.assertRaises(BoundaryError, b.getbyte)

    def test_str( self ):
        b = hotbuf(256)

        # Write into the buffer
        b.putstr(MSG)
        b.flip()

        # Read back and assert message
        self.assertEquals(b.getstr(len(MSG)), MSG)

        # Test overflow: the window now holds only len(MSG) bytes.
        b.flip()
        self.assertRaises(BoundaryError, b.putstr, ' ' * 1000)

        # Test underflow.
        self.assertRaises(BoundaryError, b.getstr, 1000)

        # getstr() without a length returns the rest of the window.
        b.clear()
        b.putstr(MSG)
        b.flip()
        s = b.getstr()
        self.assertEquals(s, MSG)

    def test_conversion( self ):
        b = hotbuf(CAPACITY)

        b.position = 100
        b.limit = 132

        # str() returns only the window contents.
        self.assertEquals(len(b), 32)
        s = str(b)
        self.assertEquals(len(s), 32)

        r = repr(b)
        self.assert_(r.startswith('<hotbuf '))

    def test_compare( self ):
        # TODO: comparison operators are not exercised yet; this only
        # checks that a buffer can be created.
        b = hotbuf(CAPACITY)

    def test_pack( self ):
        ARGS = 42, 16, '@', 3
        # Pack to a string.
        s = self.fmt.pack(*ARGS)

        # Pack directly into the buffer and compare the strings.
        b = hotbuf(CAPACITY)
        self.fmt.pack_to(b, 0, *ARGS)
        b.limit = len(s)
        self.assertEquals(str(b), s)

    def test_pack_method(self):
        ARGS = 42, 16, '@', 3
        # Pack to a string.
        s = self.fmt.pack(*ARGS)

        # Pack directly into the buffer; this advances the position by the
        # size of the packed data.
        b = hotbuf(CAPACITY)
        b.limit = len(s)
        b.pack(self.fmt, *ARGS)
        self.assertEquals(b.position, self.fmt.size)
        b.flip()
        self.assertEquals(str(b), s)

    def test_unpack( self ):
        ARGS = 42, 16, '@', 3
        b = hotbuf(CAPACITY)

        # Pack normally and put that string in the buffer.
        s = self.fmt.pack(*ARGS)
        b.putstr(s)

        # Unpack directly from the buffer and compare.
        b.flip()
        self.assertEquals(self.fmt.unpack_from(b), ARGS)

    def test_unpack_method( self ):
        ARGS = 42, 16, '@', 3
        b = hotbuf(CAPACITY)

        # Pack normally and put that string in the buffer.
        s = self.fmt.pack(*ARGS)
        b.putstr(s)

        # Unpack directly from the buffer; this advances the position by
        # the size of the unpacked data.
        b.flip()
        self.assertEquals(b.unpack(self.fmt), ARGS)
        self.assertEquals(b.position, self.fmt.size)

    def test_zerolen( self ):
        # An empty window converts to and reads as the empty string.
        b = hotbuf(CAPACITY)
        b.limit = 0
        self.assertEquals(str(b), '')
        self.assertEquals(b.getstr(), '')

    def test_count(self):
        # count() only considers the bytes inside the current window.
        b = hotbuf(CAPACITY)
        b.putstr('abcddddd')
        b.flip()
        b.limit = 3
        self.assertEquals(b.count('ab'), 1)
        self.assertEquals(b.count('d'), 0)
        self.assertEquals(b.count('1' * 100), 0)
        b.limit = 4
        self.assertEquals(b.count('ab'), 1)
        self.assertEquals(b.count('d'), 1)
        self.assertEquals(b.count('1' * 100), 0)
        b.limit = 5
        self.assertEquals(b.count('ab'), 1)
        self.assertEquals(b.count('d'), 2)
        self.assertEquals(b.count('1' * 100), 0)
        b.position += 1
        self.assertEquals(b.count('ab'), 0)
        self.assertEquals(b.count('d'), 2)
        self.assertEquals(b.count('1' * 100), 0)
        b.limit += 1
        self.assertEquals(b.count('ab'), 0)
        self.assertEquals(b.count('d'), 3)
        self.assertEquals(b.count('1' * 100), 0)

    def test_find(self):
        # find() searches the window and returns an offset relative to the
        # position, or -1 when not found.
        b = hotbuf(CAPACITY)
        b.putstr('abcddddd')
        b.flip()
        b.limit = 3
        self.assertEquals(b.find('ab'), 0)
        self.assertEquals(b.find('d'), -1)
        self.assertEquals(b.find('1' * 100), -1)
        b.limit = 4
        self.assertEquals(b.find('ab'), 0)
        self.assertEquals(b.find('d'), 3)
        self.assertEquals(b.find('1' * 100), -1)
        b.limit = 5
        self.assertEquals(b.find('ab'), 0)
        self.assertEquals(b.find('d'), 3)
        self.assertEquals(b.find('1' * 100), -1)
        b.position += 1
        self.assertEquals(b.find('ab'), -1)
        self.assertEquals(b.find('d'), 2)
        self.assertEquals(b.find('1' * 100), -1)
        b.limit += 1
        self.assertEquals(b.find('ab'), -1)
        self.assertEquals(b.find('d'), 2)
        self.assertEquals(b.find('1' * 100), -1)

    def test_nonzero(self):
        # A buffer is true while its window is non-empty.
        b = hotbuf(CAPACITY)
        self.assertEquals(bool(b), True)
        b.position = b.limit
        self.assertEquals(bool(b), False)


#------------------------------------------------------------------------
#
class HotbufUseCases(unittest.TestCase):
    """
    Use cases for the hot buffer.

    This is not so much a test but a demonstration of how to implement
    common patterns of parsing using the hot buffer.  We're trying to
    come up with the typical use cases for this kind of object, in order
    to design this buffer object as efficiently as possible.  Once this
    is done, we will implement some of these common patterns in C, as
    generically as possible (e.g. for parsing netstrings faster).

    We need to be able to parse

    - line-based protocols
    - netstring protocols
    - work with multiple sockets

    as efficiently as possible.

    We're not entirely sure of the interface that this class should have
    here so we spell out all the operations, assuming that the hot buffer
    only contains a position and a limit, and that this limits which
    region is accessible to get contents out of it.

    The hot buffer is an fixed size buffer that contains a position and a
    limit:

       X-----------X-----------------X-------------------------X
       0           position          limit                     capacity

    Where the following invariant always holds:

          0 <= position <= limit <= capacity

    The hot buffer supports the buffer protocol, and exposes only the
    region between position and limit.


    The following basic operations are assumed available (we may provide
    combinations of those for common occurrences if needed):

    - reading and setting the position of the hot buffer

    - reading and setting the limit of the hot buffer

    - compact(): move the contents to the beginning of the buffer and
                 prepare to read more

        copy [hot.position, hot.limit] to location 0
        hot.position = hot.limit - hot.position
        hot.limit = hot.capacity

    - flip(): flip the active region to the beginning part

        hot.limit = hot.position
        hot.position = 0

    - getbyte():

        c = buf[self.position]
        self.position += 1
        return c

    - getbyte(relpos):

        c = buf[self.position]
        return c

    - getstr():

        s = buf[self.position:self.limit]
        self.position self.limit
        return s

    - putstr(s):

        buf += s
        self.position += len(s)

    Note: these codes are not exception-safe in terms of the buffer's
    position and limit, i.e. if there is an exception, the position and
    limit will not be reset correctly unless you do so, or, well, to
    their previous value--whether that is correct or incorrect depends
    on your application and the side-effects that may have occurred.
    """

    def parse_newline_delim( self, hot, read, process_line ):
        """
        Use case for newline-delimited data.
        """
        cr = ord('\r')

        # Initiallly put some data into the buffer.
        hot.putstr(read(len(hot)))
        hot.flip()

        # Look over the entire input.
        while 1:
            # Save the current position and limit
            abslimit = hot.limit
            mark_position = hot.position # setmark

            # Loop over the current buffer contents
            while hot:
                # Loop over all characters
                # If we got to the end of the line
                nidx = hot.find('\n')
                if nidx != -1:
                    # Calculate how much characters are needed to
                    # backup to remove the EOL marker
                    backup = 0

                    # Make sure we don't look before the first char
                    if nidx > 0:
                        if hot.getbyterel(nidx - 1) == cr:
                            backup = 1

                    # Restrict the window to the current line
                    hot.position = mark_position
                    hot.limit = mark_position + nidx - backup

                    # Process the line.
                    process_line(hot)

                    # Advance the buffer window to the rest of the
                    # buffer after the line
                    hot.limit = abslimit
                    hot.position += nidx + 1
                    mark_position = hot.position
                else:
                    break

            # Read more data in the buffer.
            hot.compact()
            s = read(len(hot))
            if not s:
                hot.flip()
                break # Finished the input, exit.
            hot.putstr(s)
            hot.flip()

        # Process the little bit at the end.
        if hot:
            process_line(hot)

    def test_newline_delim_data( self ):
        """
        Test for newline-delimited data.
        """
        inp = StringIO(self.data_nldelim)
        hot = hotbuf(256)

        lineidx = [0]
        def assert_lines( hot ):
            "Assert the lines we process are the ones we expect."
            self.assertEquals(str(hot), self.lines_nldelim[lineidx[0]])
            lineidx[0] += 1

        self.parse_newline_delim(hot, inp.read, assert_lines)

    data_nldelim = """
Most programming languages, including Lisp, are organized
around computing the values of mathematical
functions. Expression-oriented languages (such as Lisp,
Fortran, and Algol) capitalize on the ``pun'' that an
expression that describes the value of a function may also
be interpreted as a means of computing that value. Because
of this, most programming languages are strongly biased
toward unidirectional computations (computations with
well-defined inputs and outputs). There are, however,
radically different programming languages that relax this
bias. We saw one such example in section 3.3.5, where the
objects of computation were arithmetic constraints. In a
constraint system the direction and the order of
computation are not so well specified; in carrying out a
computation the system must therefore provide more detailed
``how to'' knowledge than would be the case with an
ordinary arithmetic computation.

This does not mean, however, that the user is released
altogether from the responsibility of providing imperative
knowledge. There are many constraint networks that
implement the same set of constraints, and the user must
choose from the set of mathematically equivalent networks a
suitable network to specify a particular computation."""

    lines_nldelim = map(str.strip, data_nldelim.splitlines())


    #---------------------------------------------------------------------------

    def parse_netstrings( self, hot, read, process_msg ):
        """
        Use case for netstrings.
        """
        # Initiallly put some data into the buffer.
        hot.putstr(read(len(hot)))
        hot.flip()

        # Loop over the entire input.
        while 1:
            # Save the current limit.
            abslimit = hot.limit

            # Loop over all the messages in the current buffer.
            while hot:
                # Read the length and parse the message.
                # No error can occur here, since we're hot.
                length = hot.getbyte() 
                if len(hot) < length:
                    # Rollback the length byte and exit the loop to fill
                    # the buffer with new data.
                    hot.position -= 1 # advance(-1)
                    break

                # Window around the message content.
                limit = hot.limit = hot.position + length

                # Parse the message.
                #
                # - We are insured to be able to read all the message
                #   here because we checked for the length.
                # - Exceptions will be programming errors.
                # - You never need to deal with rollback of your transactions.

                process_msg(hot)

                # Advance beyond the message.
                hot.position = limit
                hot.limit = abslimit

            # Compact and read the next chunk of the buffer.
            hot.compact()
            s = read(len(hot))
            if not s:
                hot.flip()
                break # Finished the input, exit.
            hot.putstr(s)
            hot.flip()

        # Process a truncated message.  Maybe pass a truncated=1 flag?
        if hot:
            process_msg(hot)

    def test_netstrings( self ):
        """
        Test for parsing netstrings.
        """
        inp = StringIO(self.packed_netstrings)
        hot = hotbuf(256)

        msgidx = [0]
        def assert_msg( hot ):
            "Assert the messages we process are the ones we expect."
            msg = str(hot)
            l = len(hot)
            msgtype = chr(hot.getbyte())
            expected = self.expected_messages[msgidx[0]]
            self.assertEquals(msg, expected)
            msgidx[0] = (msgidx[0] + 1) % len(self.data_netstrings)
            

        self.parse_netstrings(hot, inp.read, assert_msg)

    #
    # Test data for netstrings.
    #

    # Formats for packing/unpacking.
    data_fmts = dict( (x[0], Struct(x[1])) for x in (('A', 'h l f'),
                                                     ('B', 'd d d'),
                                                     ('C', 'I L l c')) )

    # Expected data.
    data_netstrings = (
        ('A', (47, 23, 3.14159217)),
        ('B', (1.23, 4.232, 6.433)),
        ('A', (43, 239, 4.243232)),
        ('C', (100, 101L, 12, 'b')),
        ('B', (22.3232, 5.343, 4.3323)),
        )

    # Test data.
    expected_messages, packed_netstrings = [], ''
    for i in xrange(100):
        for msgtype, data in data_netstrings:
            stru = data_fmts[msgtype]
            msg = pack('b', ord(msgtype)) + stru.pack(*data)
            expected_messages.append(msg)
            netstring = pack('b', len(msg)) + msg
            packed_netstrings += netstring




    #---------------------------------------------------------------------------
    
    def _test_detect_boundary( self ):
        """
        Use case for arbitraty formats, where we do not set a limit for
        the processing window, and where attempting to get bytes outside
        the boundary will result in getting more data from the input.

        This implies that the processing code should be able to rollback
        and prepare to reprocess a partially processed item in case this
        happens.
        """
        while 1:
            # Catch when we hit the boundary.
            try:
                # Loop over all the messages in the current buffer.
                while hot:
                    # Save the current window in case of error
                    mark_position = hot.position
                    mark_limit = hot.limit
                    saved = True

                    # Parse the message.
                    #
                    # - We are insured to be able to read all the message
                    #   here because we checked for the length.
                    # - Exceptions will be programming errors.
                    # - You never need to deal with rollback of your
                    #   transactions.

                    # (your code)

                    # Pop the saved window
                else:
                    raise hotbuf.BoundaryError

            except hotbuf.BoundaryError:
                # Rollback the failed transaction, if there was one.
                if saved:
                    hot.position = mark_position
                    hot.limit = mark_limit

                # Compact and read the next chunk of the buffer.
                hot.compact()
                s = read(len(hot))
                if not s:
                    break # Finished the input, exit.
                hot.putstr(s)
                hot.flip()

                

    #---------------------------------------------------------------------------

    def _test_multiple_sockets( self ):
        """
        Use case for an event-based dispatcher, that may read its input
        from multiple sockets.
        """
## FIXME TODO






#------------------------------------------------------------------------
#
def test_main():
    """Run each test case class of this module, one after the other."""
    for case in (HotbufTestCase, HotbufUseCases):
        test_support.run_unittest(case)

if __name__ == "__main__":
    test_main()

