#! /usr/bin/python
# Copyright (c) 2009 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# Written by Colin Watson for Canonical Ltd.

import os
import re

import apt
import apt_pkg
try:
    from debian import deb822, debian_support
except ImportError:
    from debian_bundle import deb822, debian_support

from utils import get_output_root


# Shared APT package cache, opened once when this module is imported.
cache = apt.Cache()

# Source package caches, built lazily by init_src_cache(); both stay None
# until it has run.
#   srcrec -- source package name -> raw text of its Sources stanza
#   pkgsrc -- binary package name -> name of the source package building it
srcrec = None
pkgsrc = None

# Matches a line of 'apt-get --print-uris update' output: a single-quoted
# URI followed by a space, capturing the next space-terminated field, which
# init_src_cache() treats as a file name under Dir::State::Lists.
re_print_uris_filename = re.compile(r"'.+?' (.+?) ")
# Splits a comma-separated field (e.g. Binary), tolerating surrounding
# whitespace, including the newlines of folded continuation lines.
re_comma_sep = re.compile(r'\s*,\s*')

def init_src_cache():
    """Build a source package cache."""
    global srcrec, pkgsrc
    if srcrec is not None:
        return

    print "Building source package cache ..."
    srcrec = {}
    pkgsrc = {}

    version = {}
    binaries = {}

    # This is a somewhat ridiculous set of workarounds for APT's anaemic
    # source package database. The SourceRecords interface is inordinately
    # slow, because it searches the underlying database every single time
    # rather than keeping real lists; there really is, as far as I can see,
    # no proper way to ask APT for the list of downloaded Sources index
    # files in order to parse it ourselves; and so we must resort to looking
    # at the output of 'sudo apt-get --print-uris update' to get the list of
    # downloaded Sources files, and running them through python-debian.

    listdir = apt_pkg.config.find_dir('Dir::State::Lists')

    sources = []
    for line in get_output_root(['apt-get', '--print-uris',
                                 'update']).splitlines():
        matchobj = re_print_uris_filename.match(line)
        if not matchobj:
            continue
        filename = matchobj.group(1)
        if filename.endswith('_Sources'):
            sources.append(filename)
            print "Using file %s for apt cache" % filename

    for source in sources:
        try:
            source_file = open(os.path.join(listdir, source))
        except IOError:
            continue
        try:
            tag_file = apt_pkg.TagFile(source_file)
            for src_stanza in tag_file:
                if ('package' not in src_stanza or
                    'version' not in src_stanza or
                    'binary' not in src_stanza):
                    continue
                src = src_stanza['package']
                if (src not in srcrec or
                    (debian_support.Version(src_stanza['version']) >
                     debian_support.Version(version[src]))):
                    srcrec[src] = str(src_stanza)
                    version[src] = src_stanza['version']
                    binaries[src] = src_stanza['binary']
        finally:
            source_file.close()

    for src, pkgs in binaries.iteritems():
        for pkg in re_comma_sep.split(pkgs):
            pkgsrc[pkg] = src


class MultipleProvidesException(RuntimeError):
    """Raised by get_real_pkg when several packages provide a virtual
    package name and none of them is installed, so no provider can be
    chosen automatically."""

# Record of names that get_real_pkg resolved through Provides: maps the
# requested name to the chosen providing package's name, or to None when
# nothing provides it. Names that are real packages are not recorded.
# NOTE(review): in the code visible here this dict is only written, never
# consulted as a lookup cache.
seen_providers = {}

def get_real_pkg(pkg):
    """Resolve binary package name pkg to a concrete package name.

    If pkg is in the APT cache and has at least one version, it is returned
    unchanged. Otherwise it is treated as a virtual name and resolved via
    the packages that provide it:
      * no providers        -> None
      * exactly one         -> that provider's name
      * several providers   -> the first installed provider's name; if none
                               is installed, MultipleProvidesException is
                               raised.
    Every non-raising resolution (including None) is also stored in the
    module-level seen_providers dict under pkg.
    """
    if pkg in cache and cache[pkg].versions:
        return pkg

    candidates = cache.get_providing_packages(pkg)
    if len(candidates) == 0:
        resolved = None
    elif len(candidates) == 1:
        resolved = candidates[0].name
    else:
        # Several packages provide pkg. If one of them is already
        # installed, settle on it arbitrarily (consider libstdc++-dev).
        chosen = next((c for c in candidates if c.is_installed), None)
        if chosen is None:
            raise MultipleProvidesException(
                "Multiple packages provide %s; package must select one" % pkg)
        resolved = chosen.name

    seen_providers[pkg] = resolved
    return seen_providers[pkg]


def get_src_name(pkg):
    """Return the name of the source package that builds binary package pkg.

    pkg is first passed through get_real_pkg so that a virtual name maps to
    a providing package; if that yields None, pkg itself is used. Returns
    None when no source record (or no Package field in it) is found.
    """
    resolved = get_real_pkg(pkg)
    lookup_name = pkg if resolved is None else resolved

    src_record = get_src_record(lookup_name)
    if src_record is None or 'package' not in src_record:
        return None
    return src_record['package']


def get_src_record(src):
    """Return a parsed deb822.Sources record for source package src.

    Ensures the source cache is built first. If src is not a known source
    package name, it is retried as a binary package name, returning the
    record of the source package that builds it. Returns None if neither
    lookup succeeds.
    """
    init_src_cache()
    raw_stanza = srcrec.get(src)
    if raw_stanza is None:
        # Not a source name: try it as a binary package built by some
        # differently-named source package.
        owner = pkgsrc.get(src)
        if owner is None or owner == src:
            return None
        raw_stanza = srcrec.get(owner)
    return deb822.Sources(raw_stanza)


def get_pkg_record(pkg):
    """Return a parsed binary package record for binary package pkg.

    The record comes from the package's candidate version in the APT cache.
    Returns None when pkg has no candidate version (e.g. it is only
    virtual), consistent with the get_src_* helpers reporting a missing
    record as None. The cache[pkg] lookup itself is not guarded, so an
    unknown name still raises whatever the APT cache raises for it.
    """
    candidate = cache[pkg].candidate
    if candidate is None:
        # Previously crashed with AttributeError on None.record.
        return None
    return deb822.Packages(str(candidate.record))


def get_src_version(src):
    """Return the Version field of source package src, or None if unknown.

    src is looked up with get_src_record, which also falls back to treating
    it as a binary package name.
    """
    src_record = get_src_record(src)
    return None if src_record is None else src_record['version']


def get_src_binaries(src):
    """Return the binary packages built by source package src.

    Names come from the record's parsed Binary relation (the first
    alternative of each entry), filtered to those present in the APT cache.
    Returns None when no source record is found for src.
    """
    src_record = get_src_record(src)
    if src_record is None:
        return None

    built = [alternatives[0]['name']
             for alternatives in src_record.relations['binary']]
    return [name for name in built if name in cache]
