#! /usr/bin/env python """Add type checks to running code.""" import compiler, sys, os from compiler.ast import * from compiler import pycodegen import dis, traceback import glob import string, types import re from declparser import parse_declarations import parser def striplineno(node): """Strip line numbers from generated code so that they don't interfere with error messages.""" if isinstance(node, Node): node.lineno = None for child in node.getChildren(): striplineno(child) def gencode(template): """Generate an AST fragment from Python text""" module = compiler.parse(template) striplineno(module.node) return module.node.nodes # this is the code that we insert at the top of every function # it calls the paramcheck function with the locals() dictionary and # a type dictionary that is inserted in the code checkparam_code = gencode(""" try: __paramcheck__(locals(), __types__) except InterfaceError, e: try: __typewarn__(e) except InterfaceError: raise # if it kicks it back to me... """)[0] # we need to import the magic functions at the top of every file. # should we really do the "import *"? # I could argue that type objects should be globally available just # as exception objects are. import_code = gencode(""" try: from typecheck import __paramcheck__, __rccheck__, __declare__, __typewarn__ # one day I'll change this to list the builtin types from typecheck import * except ImportError, e: print "Error importing typecheck module:", e def __paramcheck__(*args, **args): pass __rccheck__ = __declare__ = __typewarn__ = __paramcheck__ """) # idea: Collect a list of functions and declarations. # Then go back and annotate the functions that have # associated declarations. class FunctionFixer: """Print the names of all the methods Each visit method takes two arguments, the node and its current scope. The scope is a pair representing the namespace. """ def visitModule(self, module, scope=""): """Add boilerplate to the top, then process the body, then re-process every function""" self.functions = {} self.declarations = {} self.returns = {} module.node.nodes[0:0]=import_code self.visit(module.node, scope) for funcname, func in self.functions.items(): decl = self.declarations.get(funcname, None) if decl: print " Adding type checks to", funcname func.code.nodes[0:0]=[decl.ast(), checkparam_code] for node in self.returns[funcname]: node.value = CallFunc(Name("__rccheck__"), [node.value, CallFunc(Name("locals"), [], None, None), Name("__types__") ], None, None) else: print " No declaration for ", funcname def visitClass(self, node, scopename): if scopename: scopename += "." fullname = scopename+ node.name self.visit(node.code, fullname) def visitFunction(self, node, scopename): if scopename: scopename += "." fullname = scopename + node.name self.functions[fullname]=node #keep track of it for later self.returns[fullname]=[] self.visit(node.code, fullname) def visitCallFunc(self, node, scopename): if isinstance(node.node, Name): name = node.node.name else: name = None # XXX - figure out how to remove these rather than make them # noops if name == "__declare__": assert len(node.args)==1 const = node.args[0] assert const.__class__==Const assert type(const.value) in \ (types.StringType, types.UnicodeType) decls = parse_declarations(const.value) for decl in decls: if scopename: scopename += "." fullname = scopename + decl.func_name self.declarations[fullname] = decl self.visit(node.node, scopename) def visitReturn(self, node, funcname): self.returns[funcname].append(node) class MyModuleCompiler(pycodegen.Module): def __init__(self, ast, filename): self.filename = filename self.ast = ast self.code = None def compile(self, display=0): root, filename = os.path.split(self.filename) gen = pycodegen.ModuleCodeGenerator(filename) pycodegen.walk(self.ast, gen, 1) if display: import pprint print pprint.pprint(ast) self.code = gen.getCode() def addchecks(file, typefile=None, outfile=None, gendisfile=0): if not typefile: typefile = file +"t" if not outfile: outfile = file + "c" functionfixer = FunctionFixer() f = open(typefile) buf = f.read() typeast = compiler.parse(buf) f.close() f = open(file) buf = f.read() srcast = compiler.parse(buf) f.close() # merge ASTs ast = typeast ast.node.nodes.append(srcast.node.nodes) ast.doc = srcast.doc # mutates AST compiler.walk(ast, functionfixer) gen = MyModuleCompiler(ast, file) gen.compile(0) if gendisfile: import dis import sys ostdout = sys.stdout sys.stdout = open(file+".dis", "w") dis.dis(gen.code) sys.stdout = ostdout f = open(outfile, "wb") gen.dump(f) f.close() def main(files): for file in files: typefile = file + "t" if not os.path.exists(typefile): print "No check file for", file continue checkfile = file + "c" print "Adding checks to", checkfile try: addchecks(file, typefile, checkfile) except: traceback.print_exc() def __declare__(self): "no-op operation for highlighting declarations" if __name__ == "__main__": import sys files=[] for arg in sys.argv[1:]: files.extend(glob.glob(arg)) main(files)