import os

from twisted.internet.defer import inlineCallbacks, returnValue

from juju.control.utils import get_environment
from juju.state.errors import (
    ServiceUnitStateMachineNotAssigned, MachineStateNotFound)
from juju.state.machine import MachineStateManager
from juju.state.service import ServiceStateManager
from juju.state.sshforward import prepare_ssh_sharing


def configure_subparser(subparsers):
    """Register the ``ssh`` sub-command on ``subparsers``.

    ``subparsers`` is presumably the action returned by argparse's
    ``add_subparsers()``; the help text shown for the sub-command is the
    docstring of ``command``.

    Returns the sub-parser that was created so callers can attach defaults.
    """
    ssh_parser = subparsers.add_parser("ssh", help=command.__doc__)
    ssh_parser.add_argument("unit_or_machine", help="Name of unit or machine")
    ssh_parser.add_argument(
        "--environment", "-e", help="Environment to operate on.")
    return ssh_parser


@inlineCallbacks
def get_ip_address_for_unit(client, provider, unit_name):
    """Resolve the public address of the service unit ``unit_name``.

    ``provider`` is not used here; it is accepted so this function has the
    same call signature as ``get_ip_address_for_machine``.

    Fires with a ``(public_address, service_unit_state)`` tuple. Raises
    ``ServiceUnitStateMachineNotAssigned`` when the unit reports no assigned
    machine id.
    """
    unit_state = yield ServiceStateManager(client).get_unit_state(unit_name)

    assigned_id = yield unit_state.get_assigned_machine_id()
    if assigned_id is None:
        raise ServiceUnitStateMachineNotAssigned(unit_name)

    public_address = yield unit_state.get_public_address()
    returnValue((public_address, unit_state))


@inlineCallbacks
def get_ip_address_for_machine(client, provider, machine_id):
    """Resolve the DNS name of the juju machine ``machine_id``.

    The juju machine state gives the provider instance id, and the provider
    is then asked for that machine.

    Fires with a ``(dns_name, machine_state)`` tuple.
    """
    state = yield MachineStateManager(client).get_machine_state(machine_id)
    provider_instance_id = yield state.get_instance_id()
    provider_node = yield provider.get_machine(provider_instance_id)
    returnValue((provider_node.dns_name, state))


def open_ssh(ip_address):
    """Replace the current process with ``ssh ubuntu@<ip_address>``.

    The extra options returned by ``prepare_ssh_sharing()`` are inserted
    between ``ssh`` and the destination. Because ``os.execvp`` replaces the
    running process image, this does not return on success.
    """
    # TODO: it might be nice to locate the user's private key path and pass
    # it here, i.e. the symmetric end to getting the user's public key.
    argv = ["ssh"] + list(prepare_ssh_sharing()) + ["ubuntu@%s" % ip_address]
    os.execvp("ssh", argv)


@inlineCallbacks
def command(options):
    """Launch an ssh shell on the given unit or machine.
    """
    # NOTE: the docstring above is used verbatim as the CLI help text (see
    # configure_subparser), so document this function in comments instead.
    #
    # Flow: connect to the environment, resolve ``options.unit_or_machine``
    # to an address, wait for the relevant agent to be up, refetch the
    # address if it was missing, close the connection, then exec ssh.
    environment = get_environment(options)
    provider = environment.get_machine_provider()
    client = yield provider.connect()

    target = options.unit_or_machine
    try:
        # A purely numeric name is a juju machine id; anything containing
        # a "/" is treated as a unit name.
        if target.isdigit():
            options.log.debug(
                "Fetching machine address using juju machine id.")
            lookup = get_ip_address_for_machine
            label = "machine"
        elif "/" in target:
            options.log.debug(
                "Fetching machine address using unit name.")
            lookup = get_ip_address_for_unit
            label = "unit"
        else:
            raise MachineStateNotFound(target)

        # Both lookups fire with (address, state-object-with-watch_agent).
        ip_address, agent_state = yield lookup(client, provider, target)

        # Verify the relevant agent is operational; if it does not exist
        # yet, wait on the watch until it appears.
        exists_d, watch_d = agent_state.watch_agent()
        exists = yield exists_d
        if not exists:
            options.log.info("Waiting for %s to come up.", label)
            yield watch_d

        # The first lookup may have produced no address (presumably because
        # it was not yet assigned before the agent came up); fetch it again.
        if ip_address is None:
            ip_address, _ = yield lookup(client, provider, target)
    finally:
        # Always release the connection, including when the target is
        # unknown or a lookup/watch fails; previously it leaked on errors.
        yield client.close()

    options.log.info("Connecting to %s %s at %s",
                     label, target, ip_address)
    open_ssh(ip_address)
