"""Replace has_key with in operator."""

import _ast
import concrete

def Replace(tree, old, new):
    """Replace the node old in tree with the node new.

    Walks tree with concrete.TreeIter, skipping nodes whose _fields is None.
    The walk stops at the first node that holds old in one of its AST fields,
    either as the field value itself or as an element of a list field. Only
    that one field slot is overwritten with new.

    Every visited node up to and including that one also has the first
    occurrence of old in its concrete list swapped for new. If old is never
    found in any field, the walk simply runs to the end.
    """
    for parent in concrete.TreeIter(tree):
        if parent._fields is None:
            continue

        # Keep the concrete (token-level) view in step with the AST fields.
        tokens = parent.concrete
        if old in tokens:
            tokens[tokens.index(old)] = new

        for name in parent._fields:
            value = getattr(parent, name)
            if not isinstance(value, list):
                if value is old:
                    setattr(parent, name, new)
                    return
            elif old in value:
                value[value.index(old)] = new
                return

def strip_positions(node):
    """Recursively truncate every token in node.concrete to its first two items.

    Entries of node.concrete that are AST nodes are processed recursively and
    kept as the same objects. All other entries (presumably token tuples whose
    trailing items are position data; TODO confirm against concrete) are
    sliced to [:2]. The result is a freshly built list that is then assigned
    back to node.concrete.
    """
    def _strip_one(entry):
        if not isinstance(entry, _ast.AST):
            return entry[:2]
        strip_positions(entry)
        return entry

    node.concrete = [_strip_one(entry) for entry in node.concrete]

class HasKeyTransformer:
    """Find ``x.has_key(k)`` calls in a concrete tree and build ``k in x``.

    search() locates candidate Call nodes. replace() builds the Compare node
    that should take a candidate's place. Neither method splices the
    replacement into the tree; the caller does that.
    """

    def search(self, tree):
        """Yield each node of tree that is_has_key_call() accepts.

        This is a generator over concrete.TreeIter(tree), so matches are
        produced as the caller iterates.
        """
        for node in concrete.TreeIter(tree):
            if self.is_has_key_call(node):
                yield node

    def is_has_key_call(self, node):
        """Return True if node is a call of the form ``<expr>.has_key(<arg>)``.

        Only calls with exactly one plain positional argument qualify. Calls
        with no arguments, several arguments, keyword arguments, ``*args`` or
        ``**kwargs`` are rejected. replace() can only rewrite the
        single-argument form: it reads node.args[0] and ignores everything
        else.
        """
        if not isinstance(node, _ast.Call):
            return False
        callee = node.func
        if not isinstance(callee, _ast.Attribute) or callee.attr != "has_key":
            return False
        if len(node.args) != 1 or node.keywords:
            return False
        # Python 2 ASTs keep *args / **kwargs in separate Call fields.
        if getattr(node, "starargs", None) is not None:
            return False
        if getattr(node, "kwargs", None) is not None:
            return False
        # Python 3 ASTs put a *args argument into node.args as a Starred node.
        starred = getattr(_ast, "Starred", None)
        if starred is not None and isinstance(node.args[0], starred):
            return False
        return True

    def replace(self, node):
        """Build a Compare node ``<key> in <dict>`` to stand in for node.

        node must be a call accepted by is_has_key_call(). The base of the
        callee becomes the right-hand operand, and the single call argument
        becomes the left-hand operand. strip_positions() is applied to the
        key subtree in place. The returned node is not inserted anywhere.
        """
        the_dict = node.func.value
        the_key = node.args[0]
        strip_positions(the_key)

        new = _ast.Compare()
        new.left = the_key
        the_in = _ast.In()

        # We need to synthesize a full token for "in".  Yuck!
        # NOTE(review): the stdlib tokenizer reports keywords as NAME tokens
        # (token.NAME == 1), while 3 is token.STRING there. Confirm which
        # type code concrete expects for this synthesized "in".
        the_in.concrete = [(3, "in")]
        new.ops = [the_in]
        new.comparators = [the_dict]
        # Concrete children in source order: key, the In operator, dict.
        new.concrete = [new.left] + new.ops + new.comparators
        # NOTE(review): neither operand is parenthesized. A key such as
        # ``a or b``, or a has_key call that is itself an operand of another
        # comparison, may unparse with different precedence than the original.
        # Confirm whether concrete.unparse compensates.
        return new
        
def main(args):
    """Rewrite every ``x.has_key(k)`` call in a source file as ``k in x``.

    args[0] is the path of the file to convert. The file is only read, never
    written. For each replacement, a line ``rep <new code>`` is written to
    stdout. After all replacements, the unparsed text of the whole rewritten
    tree is written to stdout.

    Raises SystemExit with a usage message if args is empty.
    """
    import sys

    if not args:
        raise SystemExit("usage: python <script> SOURCE_FILE")

    path = args[0]
    # Close the file promptly instead of leaking the handle.
    with open(path) as source_file:
        source = source_file.read()
    tree = concrete.parse(source, path)

    trans = HasKeyTransformer()
    # Snapshot the matches before editing. Replace() mutates the very tree
    # that the lazy trans.search() generator would otherwise still be walking.
    matches = list(trans.search(tree))
    for match in matches:
        replacement = trans.replace(match)
        Replace(tree, match, replacement)
        sys.stdout.write("rep %s\n" % (concrete.unparse(replacement),))

    sys.stdout.write("%s\n" % (concrete.unparse(tree),))

if __name__ == "__main__":
    import sys
    # Everything after the program name is handed to main().
    script_args = sys.argv[1:]
    main(script_args)
