# bzr-avahi - share and browse Bazaar branches with mDNS
# Copyright (C) 2007-2008 James Henstridge
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

# Make every class defined in this module a new-style class (Python 2 idiom).
__metaclass__ = type
# Public names exported by ``from <module> import *``.
# NOTE(review): send_change_notification is also a public function of this
# module but is not listed here -- confirm whether that is intentional.
__all__ = [
    'Advertiser',
    'get_mdns_advertise',
    'get_mdns_name',
    'set_mdns_advertise',
    'set_mdns_name',
    ]

from bzrlib.lazy_import import lazy_import
lazy_import(globals(), """
import urlparse

import avahi
import dbus
from dbus.lowlevel import SignalMessage
from dbus.mainloop import glib as dbus_mainloop_glib

from bzrlib import errors, trace, urlutils
from bzrlib.branch import Branch
from bzrlib.bzrdir import BzrDir
from bzrlib.transport import get_transport
""")

# Branch configuration option (read/written via branch.get_config()) that
# says whether the branch should be advertised; stored as 'True'/'False'.
MDNS_ADVERTISE = 'mdns-advertise'
# Branch configuration option overriding the advertised mDNS service name;
# when unset, the branch nick is used instead.
MDNS_NAME = 'mdns-name'

# These values are pretty arbitrary, and may change.
# D-Bus interface and signal name used on the session bus to announce that
# a branch's state has changed: emitted by send_change_notification and
# listened for by Advertiser.start.
MDNS_INTERFACE_NAME = 'org.bazaar_vcs.plugins.avahi.Notify'
MDNS_SIGNAL_NAME = 'BranchStateChanged'


class ServerInfo:
    """A bzr server whose branches may be advertised over mDNS.

    :ivar advertiser: the Advertiser that owns this server record.
    :ivar backing_url: URL of the local location the server exports.
    :ivar server_url: public URL under which the server is reachable.
    """

    def __init__(self, advertiser, backing_url, server_url):
        self.advertiser = advertiser
        self.backing_url = backing_url
        self.server_url = server_url

    @staticmethod
    def _with_trailing_slash(url):
        """Return url, appending a '/' unless it already ends with one."""
        if url.endswith('/'):
            return url
        return url + '/'

    def scan_branches(self):
        """Find every branch below backing_url and (un)advertise it."""
        trans = get_transport(self.backing_url)
        for branch in BzrDir.find_branches(trans):
            # Pass the argument to mutter (lazy formatting), consistent
            # with the other log calls in this module.
            trace.mutter('Found branch %r', branch.base)
            self.handle_branch(branch)

    def handle_branch(self, branch):
        """Advertise branch, or withdraw it if advertising is disabled.

        The public URL is obtained by taking the branch location relative
        to backing_url and joining it onto server_url.
        """
        if not get_mdns_advertise(branch):
            trace.mutter('Removing %r', branch.base)
            self.advertiser.remove_branch(branch)
            return

        relpath = urlutils.relative_url(self.backing_url, branch.base)
        public_url = urlutils.join(self.server_url, relpath)
        trace.mutter('Adding %r as %r', branch.base, public_url)
        self.advertiser.add_branch(self, branch, public_url)

    def maybe_handle_branch(self, branch):
        """Handle branch if it lives below this server's backing_url.

        :return: True if the branch belongs to this server (and was
            handled), False otherwise.
        """
        # Compare with trailing slashes so that a server for /srv/foo does
        # not claim a branch at /srv/foobar.
        parent = self._with_trailing_slash(self.backing_url)
        if not self._with_trailing_slash(branch.base).startswith(parent):
            return False
        trace.mutter('Branch %r is handled by %r',
                     branch.base, self.backing_url)

        self.handle_branch(branch)
        return True


class BranchInfo:
    def __init__(self, advertiser, server, branch, public_url):
        self.advertiser = advertiser
        self.server = server
        self.branch = branch
        self.public_url = public_url
        self.group = None

    def get_name(self):
        """Get the mDNS Service name for this branch."""
        return get_mdns_name(self.branch)

    def handle_name_conflict(self):
        """Handle a naming collision."""
        name = self.get_name()
        new_name =  self.advertiser.avahi.GetAlternativeServiceName(name)
        trace.warning('Service name collision for %r.  Renaming %r to %r.',
                      self.branch.base, name, new_name)
        set_mdns_name(self.branch, new_name)
        assert self.get_name() == new_name

    def make_group(self):
        assert self.group is None
        self.group = dbus.Interface(
            self.advertiser.system_bus.get_object(
                avahi.DBUS_NAME, self.advertiser.avahi.EntryGroupNew()),
            avahi.DBUS_INTERFACE_ENTRY_GROUP)
        self.group.connect_to_signal('StateChanged', self.state_changed)

    def add_service(self):
        if self.group is None:
            self.make_group()
        assert self.group.IsEmpty()

        scheme, authority, path, query, fragment = urlparse.urlsplit(
            self.public_url)
        port = 0
        if ':' in authority:
            port = int(authority.rsplit(':', 1)[1])
        txt = ['path=%s' % path, 'scheme=%s' % scheme]
        while True:
            try:
                self.group.AddService(
                    avahi.IF_UNSPEC, avahi.PROTO_UNSPEC, dbus.UInt32(0),
                    self.get_name(), '_bzr._tcp', '', '', dbus.UInt16(port),
                    avahi.string_array_to_txt_array(txt))
            except dbus.DBusException, exc:
                # Catch only CollisionErrors
                if exc.get_dbus_name() != ('org.freedesktop.Avahi.'
                                           'CollisionError'):
                    raise
                self.handle_name_conflict()
            else:
                break
        self.group.Commit()

    def remove_service(self):
        if self.group is not None:
            self.group.Reset()

    def state_changed(self, state, status=None):
        if state == avahi.ENTRY_GROUP_ESTABLISHED:
            trace.mutter('Service %r established.' % self.get_name())
        elif state == avahi.ENTRY_GROUP_COLLISION:
            self.handle_name_conflict()
            self.remove_service()
            self.add_service()

    def close(self):
        if self.group is not None:
            self.group.Free()
            self.group = None


class Advertiser:
    """Advertise branches of locally running bzr servers via Avahi.

    :ivar servers: ServerInfo objects keyed by server URL.
    :ivar branches: BranchInfo objects keyed by branch URL with any
        trailing slash stripped.
    :ivar server_state: last known Avahi server state, or None before
        start() has been called.
    """

    def __init__(self):
        self.servers = {}
        self.branches = {}
        # NOTE(review): thread and main are not used by this class; they
        # may be set by code elsewhere -- confirm before removing.
        self.thread = None
        self.main = None
        self.server_state = None

    def start(self):
        """Connect to D-Bus and begin tracking branch and Avahi state.

        Installs the GLib main loop as dbus-python's default, listens for
        branch change notifications on the session bus, and follows the
        Avahi server's state on the system bus.
        """
        dbus_mainloop_glib.DBusGMainLoop(set_as_default=True)
        self.session_bus = dbus.SessionBus()
        self.session_bus.add_signal_receiver(
            self.branch_state_changed,
            dbus_interface=MDNS_INTERFACE_NAME,
            signal_name=MDNS_SIGNAL_NAME)
        self.system_bus = dbus.SystemBus()
        self.avahi = dbus.Interface(
            self.system_bus.get_object(avahi.DBUS_NAME, avahi.DBUS_PATH_SERVER),
            avahi.DBUS_INTERFACE_SERVER)
        self.avahi.connect_to_signal('StateChanged', self.server_state_changed)
        self.server_state = self.avahi.GetState()

    def server_state_changed(self, state, status=None):
        """Callback for the Avahi server's StateChanged D-Bus signal.

        Services are withdrawn while the server has a host name collision
        or is (re)registering its host name, and re-added once it is
        running again.
        """
        self.server_state = state
        if state in (avahi.SERVER_COLLISION, avahi.SERVER_REGISTERING):
            for branchinfo in self.branches.itervalues():
                branchinfo.remove_service()
        elif state == avahi.SERVER_RUNNING:
            for branchinfo in self.branches.itervalues():
                # Reset first: add_service() requires an empty group, and
                # the service may still be registered if no collision or
                # registering state was seen since it was added.
                branchinfo.remove_service()
                branchinfo.add_service()

    def branch_state_changed(self, branch_url):
        """Re-evaluate a branch after receiving a change notification.

        A known branch is reprocessed by the server that registered it; an
        unknown one is offered to each server until one claims it.
        """
        branch_url = urlutils.strip_trailing_slash(branch_url)
        trace.mutter('Received branch state change notification for %r',
                     branch_url)

        branchinfo = self.branches.get(branch_url)
        if branchinfo is not None:
            # Get the previously registered server to reprocess the
            # branch.
            branchinfo.server.maybe_handle_branch(branchinfo.branch)
            return

        try:
            branch = Branch.open(branch_url)
        except errors.NotBranchError:
            # Branch.open reports "no branch here" with NotBranchError.
            trace.mutter('Invalid branch')
            return
        # See if any of our server processes want to handle the branch.
        for server in self.servers.itervalues():
            if server.maybe_handle_branch(branch):
                break

    def add_branch(self, server, branch, public_url):
        """Advertise branch at public_url on behalf of server.

        A branch that is already known is re-advertised, picking up any
        change of name, public URL or owning server.
        """
        branch_url = urlutils.strip_trailing_slash(branch.base)
        bi = self.branches.get(branch_url)
        if bi is None:
            bi = BranchInfo(self, server, branch, public_url)
            self.branches[branch_url] = bi
        else:
            # readvertise the branch, using the latest details rather than
            # the ones recorded when it was first added
            bi.remove_service()
            bi.server = server
            bi.branch = branch
            bi.public_url = public_url
        # Only publish while Avahi is running; server_state_changed adds
        # the services later otherwise.
        if self.server_state == avahi.SERVER_RUNNING:
            bi.add_service()

    def remove_branch(self, branch):
        """Stop advertising branch and forget about it."""
        branch_url = urlutils.strip_trailing_slash(branch.base)
        bi = self.branches.get(branch_url)
        if bi is not None:
            bi.close()
            del self.branches[branch_url]

    def add_server(self, backing_urls, server_url):
        """Register a server and advertise the branches it exports.

        The first backing URL with a file: scheme (after stripping any
        'readonly+' prefix) is used.  Nothing happens if none qualifies or
        server_url is already registered.
        """
        for backing_url in backing_urls:
            if backing_url.startswith('readonly+'):
                backing_url = backing_url[len('readonly+'):]
            if backing_url.startswith('file:'):
                break
        else:
            # No usable backing URL
            return
        # If this server is already registered, return.
        if server_url in self.servers:
            return
        server = ServerInfo(self, backing_url, server_url)
        self.servers[server_url] = server
        server.scan_branches()

    def remove_server(self, server_url):
        """Forget a server and stop advertising all branches it owns."""
        if server_url not in self.servers:
            return
        # Collect first: self.branches must not change while iterating.
        branches = [url for (url, branchinfo) in self.branches.iteritems()
                    if branchinfo.server.server_url == server_url]
        for url in branches:
            self.branches[url].close()
            del self.branches[url]
        del self.servers[server_url]


def get_mdns_advertise(branch):
    """Tell whether branch is configured to be advertised over mDNS.

    Only the exact option value 'True' (as written by set_mdns_advertise)
    counts as enabled; anything else, including an unset option, does not.
    """
    option_value = branch.get_config().get_user_option(MDNS_ADVERTISE)
    return option_value == 'True'


def set_mdns_advertise(branch, advertised):
    """Record in branch's configuration whether it should be advertised.

    The truth value of advertised is stored as the string 'True'/'False'.
    """
    branch_config = branch.get_config()
    flag = str(bool(advertised))
    branch_config.set_user_option(MDNS_ADVERTISE, flag)


# XXX: hashes in options do not get preserved, so we mangle them here.
# This can go when https://bugs.launchpad.net/bugs/86838 is fixed.
def get_mdns_name(branch):
    """Return the mDNS service name to advertise branch under.

    Uses the branch's mdns-name option, turning the '\\x23' escapes written
    by set_mdns_name back into '#'.  Falls back to the branch nick when the
    option is unset or empty, because Avahi rejects empty service names.
    """
    name = branch.get_config().get_user_option(MDNS_NAME)
    if name:
        return name.replace('\\x23', '#')
    return branch.nick


def set_mdns_name(branch, name):
    """Store name as the mDNS service name for branch.

    '#' characters are escaped as the text '\\x23', since hashes in option
    values are not preserved by the config layer; get_mdns_name undoes it.
    """
    branch_config = branch.get_config()
    escaped_name = name.replace('#', '\\x23')
    branch_config.set_user_option(MDNS_NAME, escaped_name)


def send_change_notification(branch, _bus_factory=None):
    """Broadcast that the state of branch has changed.

    Emits MDNS_SIGNAL_NAME of MDNS_INTERFACE_NAME from object path '/',
    carrying branch.base as its single string argument.

    :param _bus_factory: optional zero-argument callable returning the bus
        to send on; when omitted the D-Bus session bus is used.
    """
    make_bus = _bus_factory
    if make_bus is None:
        make_bus = dbus.SessionBus
    bus = make_bus()
    signal = SignalMessage('/', MDNS_INTERFACE_NAME, MDNS_SIGNAL_NAME)
    signal.append(branch.base, signature='s')
    bus.send_message(signal)
