"""distutils.install_db

Code for the database of installed Python distributions.

"""

# XXX next steps:
#   1) write test cases for this module
#   2) integrate into install_* commands
# 2.5) take it to the Distutils-SIG
#   3) write a package manager

__revision__ = "$Id$"

import os, sys
import binascii, cStringIO, sha, rfc822

from distutils.dist import DistributionMetadata

_inst_db = None
def get_install_db ():
    """get_install_db() : InstallationDatabase
    Return the module-level shared InstallationDatabase, creating it
    with the default search paths the first time it is requested.
    """
    global _inst_db
    if _inst_db is not None:
        return _inst_db
    _inst_db = InstallationDatabase()
    return _inst_db

class InstallationDatabase:
    """Collection of the distribution records reachable from a set of paths.

    Iterating over an instance yields one SoftwareDistribution per record
    found while walking the search paths.
    """

    def __init__ (self, paths=None):
        """InstallationDatabase(paths:[string] = None)
        Create a database that searches the given list of paths (files or
        directories) for distribution records.  If paths is None, the
        entries of sys.path are searched instead.
        """
        self.paths = paths
        self._cache = {}

    def get_distribution (self, dist_name):
        """get_distribution(dist_name:string) : SoftwareDistribution
        Get the object corresponding to a single distribution, or None
        if no distribution with that name is found.
        """
        try:
            return self._cache[dist_name]
        except KeyError:
            pass

        for distribution in self:
            # Key the cache by each record's own name, so a record is never
            # returned for a different name.  setdefault keeps the first
            # record seen for a name, matching what a fresh scan returns.
            self._cache.setdefault(distribution.name, distribution)
            if distribution.name == dist_name:
                return distribution

        # Misses are not cached, so a later installation can still be found.
        return None

    def list_distributions (self):
        """list_distributions() : [SoftwareDistribution]
        Return a list of all distributions installed on the system,
        enumerated in no particular order.
        """
        return list(self)

    def find_distribution (self, path):
        """find_distribution(path:string) : SoftwareDistribution
        Search and return the distribution containing the file 'path'.
        Returns None if the file doesn't belong to any distribution
        that the InstallationDatabase knows about.
        XXX should this work for directories?
        """
        for distribution in self:
            if distribution.has_file(path):
                return distribution
        return None

    def __iter__ (self):
        return _InstallDBIterator(self)

# class InstallationDatabase


class _InstallDBIterator:
    def __init__ (self, instdb):
        self.instdb = instdb
        if instdb.paths is None:
            self.queue = sys.path[:]
        else:
            self.queue = instdb.paths

    def next (self):
        if len(self.queue) == 0:
            raise StopIteration

        while len(self.queue):
            filename = self.queue.pop(0)
            if os.path.isdir(filename):
                for fn2 in os.listdir(filename):
                    self.queue.insert(0, os.path.join(filename, fn2))
            else:
                break

        return SoftwareDistribution(filename)


class SoftwareDistribution(DistributionMetadata):
    """Record of one installed distribution: metadata plus file inventory.

    Instance attributes:
    name : string
      Name of distribution
    filename : string
      Name of file in which the distribution's data is stored.
    files : {string : (size:int, perms:int, owner, group, digest:string)}
      Dictionary mapping the path of a file installed by this distribution
      to information about the file.  'owner' and 'group' are numeric ids
      when recorded by add_file() and strings when loaded by read_file();
      'digest' is a hex SHA-1 digest, or '-' when none was computed.
    requires : [string]
      List of requirements for this distribution.
    provides : [string]
      List of modules provided by this distribution.
    conflicts : [string]
      List of distributions that conflict with this distribution.
    obsoletes : [string]
      List of distributions that are rendered obsolete by this distribution.
    """

    def __init__ (self, name=None, filename=None):
        DistributionMetadata.__init__(self)
        self.files = {}
        self.filename = filename
        self.requires = []
        self.provides = []
        self.conflicts = []
        self.obsoletes = []
        # Honour an explicit name; a name found in the record file (if
        # read_pkg_info sets one) will override it below.
        if name is not None:
            self.name = name
        if filename is not None:
            self.read_file()

    def __repr__ (self):
        return '<Distribution %s: %s>' % (self.name, self.filename)

    def set_name (self, name):
        """set_name(name:string)
        Set the distribution name, also using it as the record filename
        when no filename has been set yet.
        """
        self.name = name
        if self.filename is None:
            self.filename = name

    def read_file (self):
        """Load the distribution record stored in self.filename.

        The first line names the sections present.  The sections are then
        read in the fixed order PKG-INFO, FILES, REQUIRES, PROVIDES, each
        terminated by a blank line (or end of file).
        """
        input = open(self.filename, 'rt')
        try:
            sections = input.readline().split()

            if 'PKG-INFO' in sections:
                m = rfc822.Message(input)
                self.read_pkg_info(m)

            if 'FILES' in sections:
                for line in self._read_section(input):
                    self._parse_file_line(line)

            if 'REQUIRES' in sections:
                self.requires.extend(self._read_section(input))

            if 'PROVIDES' in sections:
                self.provides.extend(self._read_section(input))
        finally:
            input.close()

    def _read_section (self, input):
        """Return the stripped lines of one section, up to a blank line/EOF."""
        lines = []
        while 1:
            line = input.readline().strip()
            if not line:
                break
            lines.append(line)
        return lines

    def _parse_file_line (self, line):
        """Parse one FILES entry and store it in self.files."""
        # as_text() separates fields with tabs, which keeps paths that
        # contain spaces intact; lines without tabs fall back to splitting
        # on any whitespace.
        if '\t' in line:
            fields = [field.strip() for field in line.split('\t')]
        else:
            fields = line.split()
        path, size, perms, owner, group, shasum = fields[:6]
        self.files[path] = (int(size), int(perms), owner, group, shasum)

    def _compute_digest (self, path):
        """Return the hex SHA digest of the file at 'path'."""
        # Binary mode, so the digest covers the exact bytes on disk.
        input = open(path, 'rb')
        try:
            return _hash_file(input)
        finally:
            input.close()

    def add_file (self, path, compute_digest=True):
        """add_file(path:string, compute_digest:bool=True) : None
        Record the size, mode, ownership and digest of an installed file.
        Paths that are not regular files are ignored.  When compute_digest
        is false, '-' is stored in place of the digest.
        XXX the file is stat()ed here; should size/perms/checksum be
        provided as parameters to this method instead?
        """
        if not os.path.isfile(path):
            return
        digest = '-'
        if compute_digest:
            digest = self._compute_digest(path)
        stats = os.stat(path)
        self.files[path] = (stats.st_size, stats.st_mode,
                            stats.st_uid, stats.st_gid, digest)

    def has_file (self, path):
        """has_file(path:string) : Boolean
        Returns true if the specified path belongs to a file in this
        distribution.
        """
        return path in self.files

    def check_file (self, path):
        """check_file(path:string) : [string]
        Compare the file at 'path' with the data recorded for it and return
        a possibly-empty list of mismatches (size, mode, owner, group,
        digest).  A recorded owner/group of 'unknown', or a recorded digest
        of '-', disables that particular check.
        """
        if not self.has_file(path):
            return ["File not part of this distribution"]
        if not os.path.exists(path):
            # Report a missing file as a mismatch instead of letting
            # os.stat() raise.
            return ["File is missing"]

        size, mode, owner, group, digest = self.files[path]
        stats = os.stat(path)

        L = []
        if stats.st_size != size:
            L.append('Modified size: %i (expected %i)' %
                     (stats.st_size, size))
        if stats.st_mode != mode:
            L.append('Modified mode: %i (expected %i)' %
                     (stats.st_mode, mode))
        # Owner/group are ints when recorded by add_file() but strings when
        # loaded by read_file(), so compare their text forms.
        if owner != 'unknown' and str(stats.st_uid) != str(owner):
            L.append('Modified user ownership: %s (expected %s)' %
                     (stats.st_uid, owner))
        if group != 'unknown' and str(stats.st_gid) != str(group):
            L.append('Modified group ownership: %s (expected %s)' %
                     (stats.st_gid, group))
        # '-' means add_file() was told not to compute a digest.
        if digest != '-' and self._compute_digest(path) != digest:
            L.append('Incorrect SHA digest')

        return L

    def as_text (self):
        """Return this distribution's record as text, in the section
        layout that read_file() parses.
        """
        output = cStringIO.StringIO()
        output.write('PKG-INFO FILES REQUIRES PROVIDES\n')
        # XXX PKG-INFO metadata is not serialized yet; only the blank line
        # terminating that section is written.
        output.write('\n')
        for path, t in self.files.items():
            output.write('%s\t%s\t%s\t%s\t%s\t%s\n' %
                         (path, t[0], t[1], t[2], t[3], t[4]))
        output.write('\n')
        for s in self.requires:
            output.write(s + '\n')
        output.write('\n')
        for s in self.provides:
            output.write(s + '\n')
        output.write('\n')
        return output.getvalue()

# class SoftwareDistribution

def _hash_file (input):
    h = sha.new()
    while 1:
        data = input.read(4096)
        if data == "":
            break
        h.update(data)
    digest = binascii.b2a_hex(h.digest())
    return digest


if __name__ == '__main__':
    db = InstallationDatabase(['/tmp/i/'])
    for p in db:
        print p.__dict__
    print db.list_distributions()
    f = open('/tmp/i2', 'wt')
    f.write(p.as_text())
    f.close()
    print p.check_file('/www/bin/apachectl')

