#! /usr/bin/env python

"""Add type checks to running code."""

import compiler, sys, os
from compiler.ast import *
from compiler import pycodegen
import dis, traceback
import glob
import string, types
import re
from declparser import parse_declarations
import parser

def striplineno(node):
    """Clear the line number of `node` and of every Node beneath it.

    Fragments produced from template text would otherwise carry the
    template's line numbers, which then interfere with error messages
    for the instrumented code.  Values that are not AST Nodes (names,
    constants, None, plain lists) are skipped.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, Node):
            current.lineno = None
            pending.extend(current.getChildren())

def gencode(template):
    """Turn Python source text into a list of top-level AST statements.

    The text is parsed with compiler.parse(), every line number in the
    result is cleared with striplineno(), and the statement list of the
    module body is returned so it can be spliced into another tree.
    """
    body = compiler.parse(template).node
    striplineno(body)
    statements = body.nodes
    return statements


# Statement inserted at the top of every function that has a declaration
# (FunctionFixer.visitModule places it right after the declaration's AST).
# It calls __paramcheck__ with the function's locals() dictionary and the
# __types__ dictionary -- presumably bound by that declaration AST; confirm
# in declparser.
# If the check raises InterfaceError, __typewarn__ is given the error:
# returning normally suppresses it, raising InterfaceError again lets it
# propagate out of the function.
# gencode() returns a list of statements; the template holds exactly one.

checkparam_code = gencode("""

try: __paramcheck__(locals(), __types__)
except InterfaceError, e:
    try: __typewarn__(e)
    except InterfaceError: raise  # if it kicks it back to me...
""")[0]

# Boilerplate inserted at the top of every instrumented module: it imports
# the magic functions that the inserted checks call.  If the typecheck
# module cannot be imported, no-op stand-ins are defined instead so the
# module still runs (without checks).
#  - __rccheck__ wraps every checked return value
#    (return expr -> return __rccheck__(expr, locals(), __types__)),
#    so its stand-in must hand the value back unchanged; otherwise every
#    instrumented function would return None.
#  - The other stand-ins accept any arguments and do nothing.
# Should we really do the "import *"?  One could argue that type objects
# should be globally available, just as exception objects are.
import_code = gencode("""
try:
    from typecheck import __paramcheck__, __rccheck__, __declare__, __typewarn__

    # one day I'll change this to list the builtin types
    from typecheck import *
except ImportError, e:
    print "Error importing typecheck module:", e
    def __paramcheck__(*args, **kwargs):
        pass
    def __rccheck__(rc, *args, **kwargs):
        return rc
    __declare__ = __typewarn__ = __paramcheck__
""")


# idea: Collect a list of functions and declarations.
#       Then go back and annotate the functions that have 
#       associated declarations.
class FunctionFixer:
    """Print the names of all the methods

    Each visit method takes two arguments, the node and its current
    scope.  The scope is a pair representing the namespace.
    """
    def visitModule(self, module, scope=""):
        """Add boilerplate to the top, then process the body, then
         re-process every function"""
        self.functions = {}
        self.declarations = {}
        self.returns = {}
        module.node.nodes[0:0]=import_code
        self.visit(module.node, scope)

        for funcname, func in self.functions.items():
            decl = self.declarations.get(funcname, None)
            if decl:
                print "    Adding type checks to", funcname
                func.code.nodes[0:0]=[decl.ast(), checkparam_code]
                for node in self.returns[funcname]:
                    node.value = CallFunc(Name("__rccheck__"), 
                        [node.value, 
                        CallFunc(Name("locals"), [], None, None),
                        Name("__types__")
                        ], None, None)
            else:
                print "    No declaration for ", funcname

    def visitClass(self, node, scopename):
        if scopename:
            scopename += "."
        fullname = scopename+ node.name
        self.visit(node.code, fullname)
    
    def visitFunction(self, node, scopename):
        if scopename:
            scopename += "."
        fullname = scopename + node.name
        self.functions[fullname]=node #keep track of it for later
        self.returns[fullname]=[]
        self.visit(node.code, fullname)

    def visitCallFunc(self, node, scopename):
        if isinstance(node.node, Name):
            name = node.node.name
        else:
            name = None

        # XXX - figure out how to remove these rather than make them
        #       noops
        if name == "__declare__":
            assert len(node.args)==1
            const = node.args[0]
            assert const.__class__==Const
            assert type(const.value) in \
                            (types.StringType, types.UnicodeType)
            decls = parse_declarations(const.value)
            for decl in decls:
                if scopename:
                    scopename += "."
                fullname = scopename + decl.func_name
                self.declarations[fullname] = decl
        self.visit(node.node, scopename)

    def visitReturn(self, node, funcname):
        self.returns[funcname].append(node)

class MyModuleCompiler(pycodegen.Module):
    """Module compiler that works from an already-built AST.

    Unlike compiling from source text, the (already modified) tree is
    passed in directly.  addchecks() calls the inherited dump() to write
    the resulting code object.
    """
    def __init__(self, ast, filename):
        # NOTE(review): the base-class constructor is not called; these
        # are the attributes compile() below (and presumably the inherited
        # dump()) rely on.
        self.filename = filename
        self.ast = ast
        self.code = None

    def compile(self, display=0):
        """Generate bytecode for self.ast and store it in self.code.

        If `display` is true, pretty-print the AST first.
        """
        # only the base name of the file is handed to the code generator
        filename = os.path.basename(self.filename)
        gen = pycodegen.ModuleCodeGenerator(filename)
        pycodegen.walk(self.ast, gen, 1)
        if display:
            import pprint
            pprint.pprint(self.ast)
        self.code = gen.getCode()

def _parsefile(path):
    """Read the file at `path` and return its compiler.parse() AST."""
    f = open(path)
    try:
        buf = f.read()
    finally:
        f.close()
    return compiler.parse(buf)

def addchecks(file, typefile=None, outfile=None, gendisfile=0):
    """Compile `file` with runtime type checks added and write the result.

    file       -- path of the Python source to instrument
    typefile   -- Python source holding the type declarations; its
                  statements are placed before those of `file`.
                  Defaults to file + "t".
    outfile    -- path the compiled code is written to; defaults to
                  file + "c"
    gendisfile -- if true, also write a disassembly to file + ".dis"
    """
    if not typefile:
        typefile = file + "t"
    if not outfile:
        outfile = file + "c"

    typeast = _parsefile(typefile)
    srcast = _parsefile(file)

    # Merge ASTs: the declaration file's statements come first, followed
    # by the source's statements (extend, not append: appending would
    # insert the whole list as a single bogus "statement").
    ast = typeast
    ast.node.nodes.extend(srcast.node.nodes)
    ast.doc = srcast.doc

    # mutates AST
    compiler.walk(ast, FunctionFixer())

    gen = MyModuleCompiler(ast, file)
    gen.compile(0)

    if gendisfile:
        # dis.dis prints to sys.stdout, so redirect it into the .dis file;
        # always restore stdout and close the file, even if dis fails.
        disfile = open(file + ".dis", "w")
        ostdout = sys.stdout
        sys.stdout = disfile
        try:
            dis.dis(gen.code)
        finally:
            sys.stdout = ostdout
            disfile.close()

    f = open(outfile, "wb")
    try:
        gen.dump(f)
    finally:
        f.close()
            
def main(files):
    for file in files:
        typefile = file + "t"
        if not os.path.exists(typefile):
            print "No check file for", file
            continue

        checkfile = file + "c"

        print "Adding checks to", checkfile
        
        try:
            addchecks(file, typefile, checkfile)
        except:
            traceback.print_exc()

def __declare__(self):
    "no-op operation for highlighting declarations"

if __name__ == "__main__":
    import sys
    # Every command-line argument is expanded as a glob pattern; a
    # pattern with no matches contributes no files.
    files = [match
             for pattern in sys.argv[1:]
             for match in glob.glob(pattern)]

    main(files)

