"""Attach raw tokens to AST to generate a "concrete" AST."""

import _ast
import tokenize
from StringIO import StringIO

# Token eating rules for AST nodes

# TODO(jhylton): The rules are too simplistic now.  They need to express:
#  -- repeated node types, e.g. multi-statement suites
#  -- optional elements, e.g. the else in an if/else statement
#  -- parentheses, though perhaps with a special case saying that
#     parentheses can only appear around expressions
#
# There is never any question about whether an optional element is present:
# if an If node's orelse field is empty, there is no else token to consume.
# So the token rules may need to associate tokens with fields.

class TokenRules:
    # Simplistic: assumes a module body of exactly two statements.
    Module = [_ast.stmt, _ast.stmt, tokenize.ENDMARKER]

    # Likewise simplistic: one statement per suite, and the else is required.
    If = [(tokenize.NAME, 'if'), _ast.expr, (tokenize.OP, ':'),
          tokenize.NEWLINE,
          tokenize.INDENT,
          _ast.stmt,
          tokenize.DEDENT,
          (tokenize.NAME, 'else'), (tokenize.OP, ':'), tokenize.NEWLINE,
          tokenize.INDENT,
          _ast.stmt,
          tokenize.DEDENT,
          ]
    Assign = [_ast.expr, (tokenize.OP, '='), _ast.expr, tokenize.NEWLINE]
    Print = [(tokenize.NAME, 'print'), _ast.expr, tokenize.NEWLINE]

    BinOp = [_ast.expr, _ast.operator, _ast.expr]
    Compare = [_ast.expr, _ast.cmpop, _ast.expr]
    Attribute = [_ast.expr, (tokenize.OP, '.'), tokenize.NAME]
    Call = [_ast.expr, (tokenize.OP, '('), _ast.expr,
            (tokenize.OP, ')')]

    Str = [tokenize.STRING]
    Name = [tokenize.NAME]
    Num = [tokenize.NUMBER]
    Add = [(tokenize.OP, '+')]
    Eq = [(tokenize.OP, '==')]
    keyword = [tokenize.NAME, (tokenize.OP, '='), _ast.expr]
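
    # Illustrative sketch (ours, not used by the engine): the Assign rule
    # above lines up with the tokens of "x = 1" as
    #   _ast.expr          -> NAME 'x'      (matched via the Name rule)
    #   (tokenize.OP, '=') -> OP '='
    #   _ast.expr          -> NUMBER '1'    (matched via the Num rule)
    #   tokenize.NEWLINE   -> NEWLINE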

def TreeIter(tree):
    """Return all the AST nodes in tree in order."""
    yield tree
    if tree._fields is None:
        return
    for fieldname in tree._fields:
        child = getattr(tree, fieldname)
        if isinstance(child, _ast.AST):
            for node in TreeIter(child):
                yield node
        elif isinstance(child, list):
            for node in child:
                for n2 in TreeIter(node):
                    yield n2
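
def _demo_tree_iter(source="x = 1\n"):
    """Hedged example -- the helper name is ours, not part of the module.

    Shows TreeIter's preorder output.  Assumes Python 2.x, where
    compile() with _ast.PyCF_ONLY_AST returns an AST.  For "x = 1" the
    result is ['Module', 'Assign', 'Name', 'Store', 'Num'].
    """
    tree = compile(source, "<demo>", "exec", _ast.PyCF_ONLY_AST)
    return [node.__class__.__name__ for node in TreeIter(tree)]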

def WalkConcrete(tree):
    """Yield the raw tokens of tree's concrete syntax in source order."""
    for elt in tree.concrete:
        if isinstance(elt, _ast.AST):
            for child in WalkConcrete(elt):
                yield child
        else:
            yield elt # a raw token
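
# Hedged sketch of the output (inferred from the rules above): after
# TreeTokenizer.run(), WalkConcrete on the tree for "x = 1\nprint x\n" yields
# the raw tokenize 5-tuples in source order -- NAME 'x', OP '=', NUMBER '1',
# NEWLINE, NAME 'print', NAME 'x', NEWLINE, ENDMARKER -- a stream that
# tokenize.untokenize() accepts.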

class TreeTokenizer:
    """Decorate AST nodes with the actual tokens that represent them.

    The TreeTokenizer is a state machine that can perform the
    following operations:

    match: The current token matches the next terminal to consume for
           the current AST node.

    walk: The current AST node's rule does not start with a terminal.
          Walk to the next node of the tree.

    The result is stored on each node's .concrete attribute: a parallel
    tree that interleaves raw tokens with child AST nodes.
    """

    DEBUG = 0

    def __init__(self, tree, tokens):
        self.root = tree
        self.tokens = tokens
        self.nodes = TreeIter(tree)

        # Initialize state of the matching engine
        self.next_node()
        self.backup_node = None
        self.token = self.tokens.next()
        self.stack = []

        # We manage two parallel stacks in order to match parens to the
        # expressions they surround.  (Not used by the engine yet.)
        self.parens = []
        self.exprs = []

    def lookup_matches(self, node):
        """Return the token matching rules for ast node."""
        name = node.__class__.__name__
        # A KeyError here means the node type has no rule yet.
        rules = TokenRules.__dict__[name]
        return list(rules) # return a copy

    def __str__(self):
        tok_type = tokenize.tok_name[self.token[0]]
        token = self.token[1]
        node = self.node.__class__.__name__
        nid = hex(id(self.node))
        matches = " ".join([str(m) for m in self.matches])
        stack = self.stack
        return ("   TreeTokenizer state\n"
                "token=%(tok_type)s:%(token)r node=%(node)s:%(nid)s\n"
                "matches=%(matches)s\n"
                "stack=%(stack)s\n"
                % locals())

    def next_node(self):
        self.node = self.nodes.next()
        # There is no concrete syntax corresponding to an expression
        # context, so skip it right here.
        # TODO(jhylton): Think about whether this is right.
        while isinstance(self.node, _ast.expr_context):
            self.node = self.nodes.next()
        self.matches = self.lookup_matches(self.node)
        if not hasattr(self.node, "concrete"):
            self.node.concrete = []

    def consume_token(self):
        """Record the current token and advance; False when tokens run out."""
        self.node.concrete.append(self.token)
        del self.matches[0]
        try:
            self.token = self.tokens.next()
            return True
        except StopIteration:
            return False

    def consume_node(self):
        """Record the just-finished child node in the parent's .concrete."""
        self.node.concrete.append(self.backup_node)
        del self.matches[0]
        self.backup_node = None

    def is_node(self, next_match):
        return (isinstance(next_match, type) and
                issubclass(next_match, _ast.AST))

    def backup(self):
        if self.DEBUG:
            print "BACKUP"
        self.backup_node = self.node
        self.node, self.matches = self.stack.pop()

    def traverse(self):
        if self.DEBUG:
            print "TRAVERSE"
        self.stack.append((self.node, self.matches))
        self.next_node()

    def match1(self, next_match):
        if self.DEBUG:
            print "MATCH 1"
        assert next_match == self.token[0], (next_match, self.token)
        return self.consume_token()

    def match2(self, next_match):
        if self.DEBUG:
            print "MATCH 2"
        token_type, token_value = next_match
        assert token_type == self.token[0], (next_match, self.token)
        assert token_value == self.token[1], (next_match, self.token)
        return self.consume_token()

    def step(self):
        """Perform one match/walk/backup operation; False means stop."""
        if self.DEBUG:
            print self

        if not self.matches:
            self.backup()
            return True

        # Check whether the next match is for a token or a node.
        next_match = self.matches[0]
        if self.is_node(next_match):
            if self.backup_node is not None:
                self.consume_node()
                return True
            else:
                self.traverse()
                return True

        if isinstance(next_match, tuple):  # specific tokens like 'def'
            return self.match2(next_match)

        if isinstance(next_match, int):   # generic tokens like a var name
            return self.match1(next_match)

        # TODO(jhylton): Figure out if you can ever get here.

        return False

    def run(self):
        while 1:
            if not self.step():
                break
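
    # Hedged trace sketch for the statement "x = 1" (operation names as
    # printed in DEBUG mode): TRAVERSE into Assign, TRAVERSE into Name,
    # MATCH 1 on NAME 'x', BACKUP to Assign (which then consumes the Name
    # node), MATCH 2 on OP '=', TRAVERSE into Num, MATCH 1 on NUMBER '1',
    # BACKUP, MATCH 1 on NEWLINE, and finally BACKUP to Module.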

def parse(source, filename="<string>", mode="exec"):
    """Parse the code in source and return a concrete AST."""
    tokens = tokenize.generate_tokens(StringIO(source).readline)
    tree = compile(source, filename, mode, _ast.PyCF_ONLY_AST)
    TreeTokenizer(tree, tokens).run()
    return tree

def unparse(tree):
    """Return source code generated from a concrete AST."""
    return tokenize.untokenize(WalkConcrete(tree))
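
def _demo_round_trip():
    """Hedged usage sketch -- the helper name is ours, not part of the module.

    The simplistic Module rule expects exactly two top-level statements,
    so the demo source has exactly two.  unparse() is expected to
    reproduce the original source text.
    """
    source = "x = 1\nprint x\n"
    tree = parse(source)
    return unparse(tree)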

if __name__ == "__main__":
    import sys

    path = sys.argv[1]
    source = open(path).read()
    tokens = tokenize.generate_tokens(open(path).readline)
    tree = compile(source, path, "exec", _ast.PyCF_ONLY_AST)
    for node in TreeIter(tree):
        print node
    tt = TreeTokenizer(tree, tokens)
    tt.run()
    code = tokenize.untokenize(WalkConcrete(tree))
    print code
