"""Get useful information from live Python objects.

This module encapsulates the interface provided by the internal special
attributes (func_*, co_*, im_*, tb_*, etc.) in a friendlier fashion.
It also provides some help for examining source code and class layout.

Here are some of the useful functions provided by this module:

    getdoc(), getcomments() - get documentation on an object
    getclasstree() - arrange classes so as to represent their hierarchy
    getfile(), getsourcefile(), getsource() - find an object's source code
    getargspec(), getargvalues() - get info about function arguments
    formatargspec(), formatargvalues() - format an argument spec
    stack(), trace() - get info about frames on the stack or in a traceback
"""

# I, Ka-Ping Yee, the author of this contribution, hereby grant to anyone and
# everyone a nonexclusive, irrevocable, royalty-free, worldwide license to
# reproduce, distribute, perform and/or display publicly, prepare derivative
# versions, and otherwise use this contribution in any fashion, or any
# derivative versions thereof, at no cost to anyone, and to authorize others
# to do so.  This software is provided "as is", with NO WARRANTY WHATSOEVER,
# not even a warranty of merchantability or fitness for any particular purpose.

# Author and date stamp for this module (a credit string, not a number).
__version__ = "Ka-Ping Yee <ping@lfw.org>, 29 May 2000"

import sys, types, string, dis, imp

# ----------------------------------------------------------- type-checking
def ismodule(object):
    """Return true if the object's type is exactly types.ModuleType."""
    kind = type(object)
    return kind is types.ModuleType

def isclass(object):
    """Is the object a class with the __module__ special attribute?"""
    return type(object) is types.ClassType

def ismethod(object):
    """Return true if the object's type is exactly types.MethodType."""
    kind = type(object)
    return kind is types.MethodType

def isfunction(object):
    """Return true if the object's type is types.FunctionType or
    types.LambdaType (a function written in Python)."""
    function_types = (types.FunctionType, types.LambdaType)
    return type(object) in function_types

def istraceback(object):
    """Return true if the object's type is exactly types.TracebackType."""
    kind = type(object)
    return kind is types.TracebackType

def isframe(object):
    """Return true if the object's type is exactly types.FrameType."""
    kind = type(object)
    return kind is types.FrameType

def iscode(object):
    """Return true if the object's type is exactly types.CodeType."""
    kind = type(object)
    return kind is types.CodeType

def isbuiltin(object):
    """Return true if the object's type is types.BuiltinFunctionType or
    types.BuiltinMethodType."""
    builtin_types = (types.BuiltinFunctionType, types.BuiltinMethodType)
    return type(object) in builtin_types

def isroutine(object):
    """Return true if the object's type is one of the function, lambda,
    method, built-in function, or built-in method types."""
    routine_types = (types.FunctionType, types.LambdaType,
                     types.MethodType, types.BuiltinFunctionType,
                     types.BuiltinMethodType)
    return type(object) in routine_types

def getmembers(object, predicate=lambda x: 1):
    """Return the members found in object.__dict__ as a list of
    (name, value) pairs sorted by name.  If a predicate is given, only
    members whose value satisfies it are included."""
    pairs = []
    names = object.__dict__.keys()
    for name in names:
        # Fetch through getattr() rather than reading the dictionary entry.
        member = getattr(object, name)
        if not predicate(member):
            continue
        pairs.append((name, member))
    pairs.sort()
    return pairs

# -------------------------------------------------- source code extraction
def indentsize(line):
    """Return the indent size, in spaces, at the start of a line of text."""
    expline = string.expandtabs(line)
    return len(expline) - len(string.lstrip(expline))

def getdoc(object):
    """Get the documentation string for an object."""
    if not hasattr(object, "__doc__"):
        raise TypeError, "arg has no __doc__ attribute"
    if object.__doc__:
        lines = string.split(string.expandtabs(object.__doc__), "\n")
        margin = None
        for line in lines[1:]:
            content = len(string.lstrip(line))
            if not content: continue
            indent = len(line) - content
            if margin is None: margin = indent
            else: margin = min(margin, indent)
        if margin is not None:
            for i in range(1, len(lines)): lines[i] = lines[i][margin:]
        return string.join(lines, "\n")

def getfile(object):
    """Try to guess which (text or binary) file an object was defined in.

    Modules and classes are looked up with imp.find_module(); methods,
    functions, tracebacks and frames are reduced to their code object,
    whose co_filename is returned.  A TypeError is raised for any other
    kind of object."""
    if ismodule(object):
        return imp.find_module(object.__name__)[1]
    if isclass(object):
        return imp.find_module(object.__module__)[1]

    # Peel off one wrapper at a time until a code object is reached.
    target = object
    if ismethod(target):
        target = target.im_func
    if isfunction(target):
        target = target.func_code
    if istraceback(target):
        target = target.tb_frame
    if isframe(target):
        target = target.f_code
    if not iscode(target):
        raise TypeError("arg is not a module, class, method, "
                        "function, traceback, frame, or code object")
    return target.co_filename

def getsourcefile(object):
    """Try to guess which Python source file an object was defined in.

    The name reported by getfile() is returned, except that a name ending
    in ".pyc" or ".pyo" (compiled bytecode) is rewritten to end in ".py".
    Whether that source file actually exists is not checked."""
    filename = getfile(object)
    # Both compiled forms sit next to the source they were built from.
    if filename[-4:] in [".pyc", ".pyo"]:
        filename = filename[:-4] + ".py"
    return filename

def findsource(object):
    """Find the first line of code corresponding to a given module, class,
    method, function, traceback, frame, or code object.

    Returns a pair (lines, lnum): 'lines' is the entire contents of the
    source file as a list of lines, and 'lnum' is the zero-based index in
    that list of the line where the object starts.  An IOError exception
    is raised if the source code cannot be retrieved or the definition
    cannot be located in it."""
    try:
        source = open(getsourcefile(object))
        try:
            lines = source.readlines()
        finally:
            # Release the file even if reading fails part-way through.
            source.close()
    except (TypeError, IOError):
        raise IOError("could not get source code")

    if ismodule(object):
        return lines, 0

    if isclass(object):
        name = object.__name__
        for i in range(len(lines)):
            words = string.split(lines[i])
            if words[:1] != ["class"] or len(words) < 2:
                continue
            # The word after "class" must be the name itself, optionally
            # followed directly by "(" (base list) or ":" -- so that
            # "class Name:", "class Name (Base):" and "class Name(Base):"
            # all match, but "class NameOther:" does not.
            head = words[1][:len(name)]
            rest = words[1][len(name):]
            if head == name and rest[:1] in ["", "(", ":"]:
                return lines, i
        raise IOError("could not find class definition")

    # Reduce methods, functions, tracebacks and frames to a code object.
    if ismethod(object):
        object = object.im_func
    if isfunction(object):
        object = object.func_code
    if istraceback(object):
        object = object.tb_frame
    if isframe(object):
        object = object.f_code
    if iscode(object):
        try:
            first = object.co_firstlineno
        except AttributeError:
            raise IOError("could not find function definition")
        # co_firstlineno counts lines from 1, 'lines' is indexed from 0.
        lnum = max(first - 1, 0)
        if lnum >= len(lines):
            # The file is shorter than the code object says it should be.
            raise IOError("could not find function definition")
        # Walk backwards to the nearest line that starts with "def";
        # give up at the top of the file and report line 0.
        while lnum > 0:
            if string.split(lines[lnum])[:1] == ["def"]: break
            lnum = lnum - 1
        return lines, lnum
    raise IOError("could not find function definition")
            
def getcomments(object):
    """Look for preceding lines of comments in an object's source code.

    For a module, the comment block at the top of the file is returned
    (after an optional "#!" line and any blank or bare "#" lines).  For
    other objects, the block of comment lines that precedes the definition
    at the same indentation is returned, each line with its leading
    whitespace removed.  The result is a single string, or None when no
    such comment block is found.  Any IOError from findsource() is passed
    through."""
    lines, lnum = findsource(object)

    if ismodule(object):
        # Look for a comment block at the top of the file.
        start = 0
        if lines and lines[0][:2] == "#!": start = 1
        while start < len(lines) and string.strip(lines[start]) in ["", "#"]:
            start = start + 1
        # 'start' may have run off the end of an empty or all-blank file.
        if start < len(lines) and lines[start][:1] == "#":
            end = start
            while end < len(lines) and lines[end][:1] == "#": end = end + 1
            return string.join(lines[start:end], "")
        return None

    # Look for a preceding block of comments at the same indentation.
    if lnum > 0:
        indent = indentsize(lines[lnum])
        end = lnum - 1
        if string.strip(lines[end]) == "":
            # Skip the blank lines directly above the definition.
            while end >= 0 and string.strip(lines[end]) == "":
                end = end - 1
        else:
            # Skip non-comment lines at the same indentation.  The bound
            # check keeps the scan from wrapping around to the end of the
            # file through negative indexing.
            while end >= 0 and string.lstrip(lines[end])[:1] != "#" and \
                indentsize(lines[end]) == indent:
                end = end - 1
        if end >= 0 and string.lstrip(lines[end])[:1] == "#" and \
            indentsize(lines[end]) == indent:
            # Collect the comment block upwards, keeping source order.
            comments = [string.lstrip(lines[end])]
            if end > 0:
                end = end - 1
                comment = string.lstrip(lines[end])
                while comment[:1] == "#" and indentsize(lines[end]) == indent:
                    comments[:0] = [comment]
                    end = end - 1
                    if end < 0: break
                    comment = string.lstrip(lines[end])
            return string.join(comments, "")
    return None

import tokenize

class ListReader:
    """Wrap a list of strings in a file-like readline() interface."""
    def __init__(self, lines):
        self.lines = lines
        self.index = 0      # position of the next line to hand out

    def readline(self):
        """Return the next line, or "" once the list is exhausted."""
        position = self.index
        if position >= len(self.lines):
            return ""
        self.index = position + 1
        return self.lines[position]

class EndOfBlock(Exception):
    """Raised by BlockFinder.tokeneater() to stop tokenizing once the end
    of the block has been seen; its argument is the last row of the block."""

class BlockFinder:
    """Provide a tokeneater() method to detect the end of a code block."""
    def __init__(self):
        self.indent = 0
        self.started = 0
        self.last = 0

    def tokeneater(self, type, token, (srow, scol), (erow, ecol), line):
        if not self.started:
            if type == tokenize.NAME: self.started = 1
        elif type == tokenize.NEWLINE:
            self.last = srow
        elif type == tokenize.INDENT:
            self.indent = self.indent + 1
        elif type == tokenize.DEDENT:
            self.indent = self.indent - 1
            if self.indent == 0: raise EndOfBlock, self.last

def getblock(lines):
    """Extract the block of code at the top of the given list of lines."""
    try:
        tokenize.tokenize(ListReader(lines).readline, BlockFinder().tokeneater)
    except EndOfBlock, eob:
        return lines[:eob.args[0]]

def getsource(object):
    """Try to get the source code corresponding to a module, class, method,
    function, traceback, frame, or code object.

    Returns a pair (lines, lnum): a list of the lines of text making up
    the object, and the zero-based index in the source file of the first
    of them.  For a module this is the whole file, starting at 0.  An
    IOError exception is raised if the source code cannot be retrieved."""
    lines, lnum = findsource(object)

    if ismodule(object):
        # Return the same (lines, lnum) shape as for every other object.
        return lines, 0
    return getblock(lines[lnum:]), lnum

# --------------------------------------------------- class tree extraction
def walktree(classes, children, parent):
    """Recursive helper function for getclasstree().

    'classes' is sorted by class name in place, then one (class, bases)
    entry is produced per class; 'bases' is None when the class has no
    bases or its only base is 'parent'.  If 'children' maps the class to
    a list of subclasses, a nested list built from them follows the
    entry."""
    # Sort by name without a comparison function: decorate each class with
    # its name and original position, so that classes sharing a name keep
    # their incoming order and class objects are never compared.
    decorated = []
    for i in range(len(classes)):
        decorated.append((classes[i].__name__, i, classes[i]))
    decorated.sort()
    for i in range(len(decorated)):
        classes[i] = decorated[i][2]

    results = []
    for c in classes:
        if c.__bases__ and c.__bases__ != (parent,):
            results.append((c, c.__bases__))
        else:
            results.append((c, None))
        subclasses = children.get(c)
        if subclasses is not None:
            results.append(walktree(subclasses, children, c))
    return results

def getclasstree(classes):
    """Arrange the given list of classes into a hierarchy of nested lists.

    Where a nested list appears, it contains classes derived from the class
    whose entry immediately precedes the list.  Each entry is a 2-tuple
    containing a class and its bases, where the bases only appear (instead
    of None) if the class inherits in a way not already implied by the tree
    structure (i.e. it inherits from multiple bases, or from a class other
    than the one its list belongs to).  Base classes that are referenced
    but not present in the given list are added as top-level entries."""
    subclasses = {}
    toplevel = []
    for klass in classes:
        bases = klass.__bases__
        if not bases:
            # No bases at all: this class is a root of the hierarchy.
            if klass not in toplevel:
                toplevel.append(klass)
            continue
        for base in bases:
            if not subclasses.has_key(base):
                subclasses[base] = []
            subclasses[base].append(klass)
            # Stop at the first base that is itself in the given list.
            if base in classes:
                break
    for base in subclasses.keys():
        if base not in classes:
            toplevel.append(base)
    return walktree(toplevel, subclasses, None)

# ------------------------------------------------ argument list extraction
# Flag bits tested against a code object's co_flags; the values are taken
# from Python's compile.h.
CO_OPTIMIZED = 1
CO_NEWLOCALS = 2
CO_VARARGS = 4
CO_VARKEYWORDS = 8

def getargs(co):
    """Get information about the arguments accepted by a code object.

    Three things are returned: (args, varargs, varkw), where 'args' is
    a list of argument names (possibly containing nested lists), and
    'varargs' and 'varkw' are the names of the * and ** arguments or None.
    A TypeError is raised if 'co' is not a code object."""
    if not iscode(co): raise TypeError("arg is not a code object")

    code = co.co_code
    nargs = co.co_argcount
    names = co.co_varnames
    args = list(names[:nargs])
    step = 0        # read position in the bytecode, shared by all arguments

    # The following acrobatics are for anonymous (tuple) arguments.
    for i in range(nargs):
        # An argument whose name is empty or starts with "." is treated as
        # a tuple argument; its real names are recovered from the bytecode.
        if args[i][:1] in ["", "."]:
            # stack:  names collected so far (with finished groups nested)
            # remain: per open tuple, how many elements are still expected
            # count:  per open tuple, its total number of elements
            stack, remain, count = [], [], []
            while step < len(code):
                op = ord(code[step])
                step = step + 1
                if op >= dis.HAVE_ARGUMENT:
                    opname = dis.opname[op]
                    # The operand is two bytes, low byte first.
                    value = ord(code[step]) + ord(code[step+1])*256
                    step = step + 2
                    if opname == "UNPACK_TUPLE":
                        remain.append(value)
                        count.append(value)
                    elif opname == "STORE_FAST":
                        stack.append(names[value])
                        remain[-1] = remain[-1] - 1
                        # Each completed tuple is folded into one nested
                        # list, which in turn counts as one element of the
                        # tuple enclosing it.
                        while remain[-1] == 0:
                            remain.pop()
                            size = count.pop()
                            stack[-size:] = [stack[-size:]]
                            if not remain: break
                            remain[-1] = remain[-1] - 1
                        if not remain: break
            args[i] = stack[0]

    # The * and ** names, when present, follow the positional arguments
    # in co_varnames, in that order.
    varargs = None
    if co.co_flags & CO_VARARGS:
        varargs = co.co_varnames[nargs]
        nargs = nargs + 1
    varkw = None
    if co.co_flags & CO_VARKEYWORDS:
        varkw = co.co_varnames[nargs]
    return args, varargs, varkw

def getargspec(func):
    """Get the names and default values of a function's arguments.

    Returns the tuple (args, varargs, varkw, defaults): the three values
    reported by getargs() for the function's code object, followed by the
    function's func_defaults.  'args' is a list of the argument names (it
    may contain nested lists); 'varargs' and 'varkw' are the names of the
    * and ** arguments or None.  A TypeError is raised if 'func' is not a
    Python function."""
    if not isfunction(func):
        raise TypeError("arg is not a Python function")
    names = getargs(func.func_code)
    return names + (func.func_defaults,)

def getargvalues(frame):
    """Get information about arguments passed into a particular frame.

    Returns the tuple (args, varargs, varkw, locals): the three values
    reported by getargs() for the frame's code object, followed by the
    frame's f_locals dictionary."""
    names = getargs(frame.f_code)
    return names + (frame.f_locals,)

def strtuple(object, convert=str):
    """Recursively walk a tuple, stringifying each element."""
    if type(object) in [type(()), type([])]:
        results = map(strtuple, object)
        if len(results) == 1:
            return "(" + results[0] + ",)"
        else:
            return "(" + string.join(results, ", ") + ")"
    else: return convert(object)

def formatargspec(args, varargs=None, varkw=None, defaults=None,
                  argformat=str, defaultformat=lambda x: "=" + repr(x),
                  varargsformat=lambda name: "*" + name,
                  varkwformat=lambda name: "**" + name):
    """Make a nicely-formatted argument spec from the output of getargspec.

    The result is a parenthesized, comma-separated string.  The four
    *format arguments are the functions used to render argument names,
    default values, the * argument and the ** argument respectively."""
    pieces = []
    if defaults:
        # Defaults belong to the last len(defaults) arguments.
        firstdefault = len(args) - len(defaults)
    position = 0
    for arg in args:
        text = strtuple(arg, argformat)
        if defaults and position >= firstdefault:
            text = text + defaultformat(defaults[position - firstdefault])
        pieces.append(text)
        position = position + 1
    if varargs:
        pieces.append(varargsformat(varargs))
    if varkw:
        pieces.append(varkwformat(varkw))
    return "(" + string.join(pieces, ", ") + ")"

def formatargvalues(args, varargs=None, varkw=None, locals=None,
                    argformat=str, valueformat=repr,
                    varargsformat=lambda name: "*" + name,
                    varkwformat=lambda name: "**" + name):
    """Make a nicely-formatted argument list from the output of
    getargvalues.

    Each argument is shown as name=value, with the value looked up in
    'locals'; the * and ** arguments are shown by name only.  The result
    is a parenthesized, comma-separated string."""
    # The default arguments carry the outer values into the helper.
    def pair(name, locals=locals,
             argformat=argformat, valueformat=valueformat):
        return argformat(name) + "=" + valueformat(locals[name])

    pieces = []
    for arg in args:
        pieces.append(strtuple(arg, pair))
    if varargs:
        pieces.append(varargsformat(varargs))
    if varkw:
        pieces.append(varkwformat(varkw))
    return "(" + string.join(pieces, ", ") + ")"
    
# -------------------------------------------------- stack frame extraction
def getframe(frame, context=1):
    """Describe a frame or traceback object.

    Returns the tuple (filename, lineno, function name, lines, index):
    'lines' is a list of up to 'context' lines of source code centred on
    the current line, and 'index' is the position of the current line
    within that list.  Both are None if 'context' is not positive or the
    source cannot be retrieved.  A TypeError is raised if the argument is
    neither a frame nor a traceback."""
    if istraceback(frame):
        frame = frame.tb_frame
    if not isframe(frame):
        raise TypeError("arg is not a frame or traceback object")

    filename = getsourcefile(frame)
    lineno = frame.f_lineno
    lines = index = None
    if context > 0:
        try:
            source, lnum = findsource(frame)
        except (IOError, IndexError):
            # Best effort: without readable source, report no context.
            pass
        else:
            # Centre the window on the current line, then keep it inside
            # the file: not past the end, and not before the first line.
            start = lineno - 1 - context/2
            start = min(start, len(source) - context)
            start = max(start, 0)
            lines = source[start:start+context]
            index = lineno - 1 - start

    return (filename, lineno, frame.f_code.co_name, lines, index)

def getouterframes(frame, context=1):
    """Get a list of records for a frame and all higher (calling) frames.

    Each record is the frame object followed by the values getframe()
    reports for it: filename, line number, function name, the requested
    amount of context, and index within the context.  The given frame
    comes first, then each caller found by following f_back."""
    records = []
    current = frame
    while current:
        record = (current,) + getframe(current, context)
        records.append(record)
        current = current.f_back
    return records

def getinnerframes(traceback, context=1):
    """Get a list of records for the frames below a traceback.

    Each record contains a frame object, filename, line number, function
    name, the requested amount of context, and index within the context.
    The walk starts at traceback.tb_next, so the frame of the traceback
    entry passed in is not itself included."""
    traceback = traceback.tb_next
    framelist = []
    while traceback:
        # Pair each entry's own frame with the description of that entry.
        framelist.append((traceback.tb_frame,) + getframe(traceback, context))
        traceback = traceback.tb_next
    return framelist

def currentframe():
    """Return the frame object for the caller's stack frame."""
    # Raise and catch an exception to get a traceback whose frame is this
    # function's own; its f_back is then the caller's frame.
    try:
        raise Exception("catch me")
    except Exception:
        # sys.exc_info() is per-thread, unlike the sys.exc_traceback global.
        return sys.exc_info()[2].tb_frame.f_back

def stack(context=1):
    """Return a list of records, as built by getouterframes(), for the
    caller's frame and every frame above it."""
    # currentframe() hands back this function's own frame; step out once
    # to start from whoever called stack().
    here = currentframe()
    return getouterframes(here.f_back, context)

def trace(context=1):
    """Return a list of records, as built by getinnerframes(), for the
    stack below the exception currently being handled."""
    # sys.exc_info() is per-thread, unlike the sys.exc_traceback global.
    return getinnerframes(sys.exc_info()[2], context)
