# -*- coding: utf-8 -*-
"""
    sphinx.textwriter
    ~~~~~~~~~~~~~~~~~

    Custom docutils writer for plain text.

    :copyright: 2008 by Georg Brandl.
    :license: BSD.
"""

import re
import textwrap

from docutils import nodes, writers

from sphinx import addnodes


class TextWriter(writers.Writer):
    supported = ('text',)
    settings_spec = ('No options here.', '', ())
    settings_defaults = {}

    output = None

    def __init__(self, builder):
        writers.Writer.__init__(self)
        self.builder = builder

    def translate(self):
        visitor = TextTranslator(self.document, self.builder)
        self.document.walkabout(visitor)
        self.output = visitor.body

# monkey-patch...
new_wordsep_re = re.compile(
        r'(\s+|'                                  # any whitespace
        r'(?<=\s)(?::[a-z-]+:)?`\S+|'             # interpreted text start
        r'[^\s\w]*\w+[a-zA-Z]-(?=\w+[a-zA-Z])|'   # hyphenated words
        r'(?<=[\w\!\"\'\&\.\,\?])-{2,}(?=\w))')   # em-dash
textwrap.TextWrapper.wordsep_re = new_wordsep_re

MAXWIDTH = 70
STDINDENT = 3


class TextTranslator(nodes.NodeVisitor):
    sectionchars = '*=-~"+'

    def __init__(self, document, builder):
        nodes.NodeVisitor.__init__(self, document)

        self.states = [[]]
        self.stateindent = [0]
        self.sectionlevel = 0
        self.table = None

    def add_text(self, text):
        self.states[-1].append((-1, text))
    def new_state(self, indent=STDINDENT):
        self.states.append([])
        self.stateindent.append(indent)
    def end_state(self, wrap=True, end=[''], first=None):
        content = self.states.pop()
        maxindent = sum(self.stateindent)
        indent = self.stateindent.pop()
        result = []
        toformat = []
        def do_format():
            if not toformat:
                return
            if wrap:
                res = textwrap.wrap(''.join(toformat), width=MAXWIDTH-maxindent)
            else:
                res = ''.join(toformat).splitlines()
            if end:
                res += end
            result.append((indent, res))
        for itemindent, item in content:
            if itemindent == -1:
                toformat.append(item)
            else:
                do_format()
                result.append((indent + itemindent, item))
                toformat = []
        do_format()
        if first is not None and result:
            itemindent, item = result[0]
            if item:
                result.insert(0, (itemindent - indent, [first + item[0]]))
                result[1] = (itemindent, item[1:])
        self.states[-1].extend(result)

    def visit_document(self, node):
        self.new_state(0)
    def depart_document(self, node):
        self.end_state()
        self.body = '\n'.join(line and (' '*indent + line)
                              for indent, lines in self.states[0]
                              for line in lines)
        # XXX header/footer?

    def visit_highlightlang(self, node):
        raise nodes.SkipNode

    def visit_section(self, node):
        self._title_char = self.sectionchars[self.sectionlevel]
        self.sectionlevel += 1
    def depart_section(self, node):
        self.sectionlevel -= 1

    def visit_topic(self, node):
        self.new_state(0)
    def depart_topic(self, node):
        self.end_state()

    visit_sidebar = visit_topic
    depart_sidebar = depart_topic

    def visit_rubric(self, node):
        self.new_state(0)
        self.add_text('-[ ')
    def depart_rubric(self, node):
        self.add_text(' ]-')
        self.end_state()

    def visit_glossary(self, node):
        pass
    def depart_glossary(self, node):
        pass

    def visit_title(self, node):
        if isinstance(node.parent, nodes.Admonition):
            self.add_text(node.astext()+': ')
            raise nodes.SkipNode
        self.new_state(0)
    def depart_title(self, node):
        if isinstance(node.parent, nodes.section):
            char = self._title_char
        else:
            char = '^'
        text = ''.join(x[1] for x in self.states.pop() if x[0] == -1)
        self.stateindent.pop()
        self.states[-1].append((0, ['', text, '%s' % (char * len(text)), '']))

    def visit_module(self, node):
        if node.has_key('platform'):
            self.new_state(0)
            self.add_text('Platform: %s' % node['platform'])
            self.end_state()
        raise nodes.SkipNode

    def visit_desc(self, node):
        pass
    def depart_desc(self, node):
        pass

    def visit_desc_signature(self, node):
        self.new_state(0)
        if node.parent['desctype'] in ('class', 'exception'):
            self.add_text('%s ' % node.parent['desctype'])
    def depart_desc_signature(self, node):
        # XXX: wrap signatures in a way that makes sense
        self.end_state(wrap=False, end=None)

    def visit_desc_name(self, node):
        pass
    def depart_desc_name(self, node):
        pass

    def visit_desc_addname(self, node):
        pass
    def depart_desc_addname(self, node):
        pass

    def visit_desc_type(self, node):
        pass
    def depart_desc_type(self, node):
        pass

    def visit_desc_parameterlist(self, node):
        self.add_text('(')
        self.first_param = 1
    def depart_desc_parameterlist(self, node):
        self.add_text(')')

    def visit_desc_parameter(self, node):
        if not self.first_param:
            self.add_text(', ')
        else:
            self.first_param = 0
        self.add_text(node.astext())
        raise nodes.SkipNode

    def visit_desc_optional(self, node):
        self.add_text('[')
    def depart_desc_optional(self, node):
        self.add_text(']')

    def visit_refcount(self, node):
        pass
    def depart_refcount(self, node):
        pass

    def visit_desc_content(self, node):
        self.new_state()
        self.add_text('\n')
    def depart_desc_content(self, node):
        self.end_state()

    def visit_figure(self, node):
        self.new_state()
    def depart_figure(self, node):
        self.end_state()

    def visit_caption(self, node):
        pass
    def depart_caption(self, node):
        pass

    def visit_productionlist(self, node):
        self.new_state()
        names = []
        for production in node:
            names.append(production['tokenname'])
        maxlen = max(len(name) for name in names)
        for production in node:
            if production['tokenname']:
                self.add_text(production['tokenname'].ljust(maxlen) + ' ::=')
                lastname = production['tokenname']
            else:
                self.add_text('%s    ' % (' '*len(lastname)))
            self.add_text(production.astext() + '\n')
        self.end_state(wrap=False)
        raise nodes.SkipNode

    def visit_seealso(self, node):
        self.new_state()
    def depart_seealso(self, node):
        self.end_state(first='')

    def visit_footnote(self, node):
        self._footnote = node.children[0].astext().strip()
        self.new_state(len(self._footnote) + 3)
    def depart_footnote(self, node):
        self.end_state(first='[%s] ' % self._footnote)

    def visit_label(self, node):
        raise nodes.SkipNode

    def visit_tabular_col_spec(self, node):
        raise nodes.SkipNode

    def visit_colspec(self, node):
        self.table[0].append(node['colwidth'])
        raise nodes.SkipNode

    def visit_tgroup(self, node):
        pass
    def depart_tgroup(self, node):
        pass

    def visit_thead(self, node):
        pass
    def depart_thead(self, node):
        pass

    def visit_tbody(self, node):
        self.table.append('sep')
    def depart_tbody(self, node):
        pass

    def visit_row(self, node):
        self.table.append([])
    def depart_row(self, node):
        pass

    def visit_entry(self, node):
        if node.has_key('morerows') or node.has_key('morecols'):
            raise NotImplementedError('Column or row spanning cells are '
                                      'not implemented.')
        self.new_state(0)
    def depart_entry(self, node):
        text = '\n'.join('\n'.join(x[1]) for x in self.states.pop())
        self.stateindent.pop()
        self.table[-1].append(text)

    def visit_table(self, node):
        if self.table:
            raise NotImplementedError('Nested tables are not supported.')
        self.new_state(0)
        self.table = [[]]
    def depart_table(self, node):
        lines = self.table[1:]
        fmted_rows = []
        colwidths = self.table[0]
        realwidths = colwidths[:]
        separator = 0
        # don't allow paragraphs in table cells for now
        for line in lines:
            if line == 'sep':
                separator = len(fmted_rows)
            else:
                cells = []
                for i, cell in enumerate(line):
                    par = textwrap.wrap(cell, width=colwidths[i])
                    if par:
                        maxwidth = max(map(len, par))
                    else:
                        maxwidth = 0
                    realwidths[i] = max(realwidths[i], maxwidth)
                    cells.append(par)
                fmted_rows.append(cells)

        def writesep(char='-'):
            out = ['+']
            for width in realwidths:
                out.append(char * (width+2))
                out.append('+')
            self.add_text(''.join(out) + '\n')

        def writerow(row):
            lines = map(None, *row)
            for line in lines:
                out = ['|']
                for i, cell in enumerate(line):
                    if cell:
                        out.append(' ' + cell.ljust(realwidths[i]+1))
                    else:
                        out.append(' ' * (realwidths[i] + 2))
                    out.append('|')
                self.add_text(''.join(out) + '\n')

        for i, row in enumerate(fmted_rows):
            if separator and i == separator:
                writesep('=')
            else:
                writesep('-')
            writerow(row)
        writesep('-')
        self.table = None
        self.end_state(wrap=False)

    def visit_acks(self, node):
        self.new_state(0)
        self.add_text(', '.join(n.astext() for n in node.children[0].children) + '.')
        self.end_state()
        raise nodes.SkipNode

    def visit_image(self, node):
        self.add_text('[image]')
        raise nodes.SkipNode

    def visit_transition(self, node):
        indent = sum(self.stateindent)
        self.new_state(0)
        self.add_text('=' * (MAXWIDTH - indent))
        self.end_state()
        raise nodes.SkipNode

    def visit_bullet_list(self, node):
        self._list_counter = -1
    def depart_bullet_list(self, node):
        pass

    def visit_enumerated_list(self, node):
        self._list_counter = 0
    def depart_enumerated_list(self, node):
        pass

    def visit_definition_list(self, node):
        self._list_counter = -2
    def depart_definition_list(self, node):
        pass

    def visit_list_item(self, node):
        if self._list_counter == -1:
            # bullet list
            self.new_state(2)
        elif self._list_counter == -2:
            # definition list
            pass
        else:
            # enumerated list
            self._list_counter += 1
            self.new_state(len(str(self._list_counter)) + 2)
    def depart_list_item(self, node):
        if self._list_counter == -1:
            self.end_state(first='* ', end=None)
        elif self._list_counter == -2:
            pass
        else:
            self.end_state(first='%s. ' % self._list_counter, end=None)

    def visit_definition_list_item(self, node):
        self._li_has_classifier = len(node) >= 2 and \
                                  isinstance(node[1], nodes.classifier)
    def depart_definition_list_item(self, node):
        pass

    def visit_term(self, node):
        self.new_state(0)
    def depart_term(self, node):
        if not self._li_has_classifier:
            self.end_state(end=None)

    def visit_classifier(self, node):
        self.add_text(' : ')
    def depart_classifier(self, node):
        self.end_state(end=None)

    def visit_definition(self, node):
        self.new_state()
    def depart_definition(self, node):
        self.end_state()

    def visit_field_list(self, node):
        pass
    def depart_field_list(self, node):
        pass

    def visit_field(self, node):
        pass
    def depart_field(self, node):
        pass

    def visit_field_name(self, node):
        self.new_state(0)
    def depart_field_name(self, node):
        self.add_text(':')
        self.end_state(end=None)

    def visit_field_body(self, node):
        self.new_state()
    def depart_field_body(self, node):
        self.end_state()

    def visit_centered(self, node):
        pass
    def depart_centered(self, node):
        pass

    def visit_admonition(self, node):
        self.new_state(0)
    def depart_admonition(self, node):
        self.end_state()

    def _make_visit_admonition(name):
        def visit_admonition(self, node):
            self.new_state(2)
        return visit_admonition
    def _make_depart_admonition(name):
        def depart_admonition(self, node):
            self.end_state(first=name.capitalize() + ': ')
        return depart_admonition

    visit_attention = _make_visit_admonition('attention')
    depart_attention = _make_depart_admonition('attention')
    visit_caution = _make_visit_admonition('caution')
    depart_caution = _make_depart_admonition('caution')
    visit_danger = _make_visit_admonition('danger')
    depart_danger = _make_depart_admonition('danger')
    visit_error = _make_visit_admonition('error')
    depart_error = _make_depart_admonition('error')
    visit_hint = _make_visit_admonition('hint')
    depart_hint = _make_depart_admonition('hint')
    visit_important = _make_visit_admonition('important')
    depart_important = _make_depart_admonition('important')
    visit_note = _make_visit_admonition('note')
    depart_note = _make_depart_admonition('note')
    visit_tip = _make_visit_admonition('tip')
    depart_tip = _make_depart_admonition('tip')
    visit_warning = _make_visit_admonition('warning')
    depart_warning = _make_depart_admonition('warning')

    def visit_versionmodified(self, node):
        self.new_state(0)
        if node['type'] == 'versionadded':
            tmpl = 'Added in version %s'
        elif node['type'] == 'versionchanged':
            tmpl = 'Changed in version %s'
        elif node['type'] == 'deprecated':
            tmpl = 'Deprecated in version %s'
        if node.children:
            self.add_text(tmpl % node['version'] + ': ')
        else:
            self.add_text(tmpl % node['version'] + '.')
    def depart_versionmodified(self, node):
        self.end_state()

    def visit_literal_block(self, node):
        self.new_state()
    def depart_literal_block(self, node):
        self.end_state(wrap=False)

    def visit_doctest_block(self, node):
        self.new_state(0)
    def depart_doctest_block(self, node):
        self.end_state(wrap=False)

    def visit_line_block(self, node):
        self.new_state(0)
    def depart_line_block(self, node):
        self.end_state(wrap=False)

    def visit_line(self, node):
        pass
    def depart_line(self, node):
        pass

    def visit_block_quote(self, node):
        self.new_state()
    def depart_block_quote(self, node):
        self.end_state()

    def visit_compact_paragraph(self, node):
        pass
    def depart_compact_paragraph(self, node):
        pass

    def visit_paragraph(self, node):
        if not isinstance(node.parent, nodes.Admonition) or \
               isinstance(node.parent, addnodes.seealso):
            self.new_state(0)
    def depart_paragraph(self, node):
        if not isinstance(node.parent, nodes.Admonition) or \
               isinstance(node.parent, addnodes.seealso):
            self.end_state()

    def visit_target(self, node):
        raise nodes.SkipNode

    def visit_index(self, node):
        raise nodes.SkipNode

    def visit_substitution_definition(self, node):
        raise nodes.SkipNode

    def visit_pending_xref(self, node):
        pass
    def depart_pending_xref(self, node):
        pass

    def visit_reference(self, node):
        pass
    def depart_reference(self, node):
        pass

    def visit_emphasis(self, node):
        self.add_text('*')
    def depart_emphasis(self, node):
        self.add_text('*')

    def visit_literal_emphasis(self, node):
        self.add_text('*')
    def depart_literal_emphasis(self, node):
        self.add_text('*')

    def visit_strong(self, node):
        self.add_text('**')
    def depart_strong(self, node):
        self.add_text('**')

    def visit_title_reference(self, node):
        self.add_text('*')
    def depart_title_reference(self, node):
        self.add_text('*')

    def visit_literal(self, node):
        self.add_text('``')
    def depart_literal(self, node):
        self.add_text('``')

    def visit_subscript(self, node):
        self.add_text('_')
    def depart_subscript(self, node):
        pass

    def visit_superscript(self, node):
        self.add_text('^')
    def depart_superscript(self, node):
        pass

    def visit_footnote_reference(self, node):
        self.add_text('[%s]' % node.astext())
        raise nodes.SkipNode

    def visit_Text(self, node):
        self.add_text(node.astext())
    def depart_Text(self, node):
        pass

    def visit_problematic(self, node):
        self.add_text('>>')
    def depart_problematic(self, node):
        self.add_text('<<')

    def visit_system_message(self, node):
        self.new_state(0)
        self.add_text('<SYSTEM MESSAGE: %s>' % node.astext())
        self.end_state()
        raise nodes.SkipNode

    def visit_comment(self, node):
        raise nodes.SkipNode

    def unknown_visit(self, node):
        raise NotImplementedError("Unknown node: " + node.__class__.__name__)
