#!/usr/bin/env python
#
#   ConVirt   -  Copyright (c) 2008 Jd & Hap Hazard
#   ======
#
# ConVirt is a Xen management tool with a GTK based graphical interface
# that allows for performing the standard set of domain operations
# (start, stop, pause, kill, shutdown, reboot, snapshot, etc...). It
# also attempts to simplify certain aspects such as the creation of
# domains, as well as making the consoles available directly within the
# tool's user interface.
#
#
# This software is subject to the GNU Lesser General Public License (LGPL)
# and for details, please consult it at:
#
#    http://www.fsf.org/licensing/licenses/lgpl.txt
#


from ManagedNode import ManagedNode
from XenNode import XenNode
from Groups import ServerGroup
from constants import *
import utils
from pprint import pprint

import os, socket, sys, tempfile, traceback

class GridManager:
    """Track managed Xen servers ("nodes") and the server groups holding them.

    A node is either uncategorized, in which case it lives in ``node_list``
    keyed by hostname, or a member of one or more ``ServerGroup`` objects
    held in ``group_list`` keyed by group name.  Hosts and groups are loaded
    from, and written back to, the persistent ``store`` given at
    construction time.
    """

    # creds_helper will eventually become part of XenNodeFactory
    def __init__(self, store, config, creds_helper):
        """Load all known groups and hosts from the persistent store.

        store        -- persistent store (getHosts, getGroups, host
                        property accessors, saveGroups, removeHost).
        config       -- client configuration; supplies the default ssh
                        port, migration port and xen protocol.
        creds_helper -- handed unchanged to every XenNode created.

        If the store holds no groups, a default set is created and saved.
        """
        self.store = store    # persistent store
        self.config = config  # config store
        self.group_list = {}  # group name -> ServerGroup
        self.node_list = {}   # hostname -> node, for nodes not in any group

        # initialize state from the store
        host_names = store.getHosts()
        group_map = store.getGroups()
        if group_map is None:
            group_map = {}

        for name in group_map.keys():
            self.group_list[name] = ServerGroup(name)

        if len(self.group_list) == 0:
            self._create_default_groups()

        # Client-wide defaults do not depend on the host: read them once.
        default_ssh_port = config.get(utils.XMConfig.CLIENT_CONFIG,
                                      prop_default_ssh_port)
        default_migration_port = config.get(utils.XMConfig.CLIENT_CONFIG,
                                            prop_default_migration_port)
        default_xen_protocol = config.get(utils.XMConfig.CLIENT_CONFIG,
                                          prop_default_xen_protocol)

        for host in host_names:
            node = self._load_node(host, creds_helper,
                                   default_ssh_port,
                                   default_migration_port,
                                   default_xen_protocol)

            grps = self._find_groups(host, group_map)
            if len(grps) == 0:
                self.node_list[node.hostname] = node
            else:
                for g in grps:
                    group = self.group_list.get(g)
                    if group is not None:
                        group._addNode(node)

    def _load_node(self, host, creds_helper, default_ssh_port,
                   default_migration_port, default_xen_protocol):
        """Build a XenNode for *host* from its stored properties.

        Missing per-host ports/protocol fall back to the supplied client
        defaults, then to 22 (ssh), 8002 (migration) and "tcp".  If the
        isRemote property is empty/missing, it is detected with
        utils.is_host_remote and written back to the store.  A missing or
        blank login defaults to 'root'.
        """
        store = self.store

        _remote = store.getHostProperty(prop_isRemote, host)
        if not _remote:
            remote = utils.is_host_remote(host)
            store.setHostProperty(prop_isRemote, str(remote), host)
        else:
            # Written as str(bool); parse it rather than eval() store data.
            remote = self._parse_bool(_remote)

        username = store.getHostProperty(prop_login, host)
        xen_port = store.getHostProperty(prop_xen_port, host)
        ssh_port = store.getHostProperty(prop_ssh_port, host)
        migration_port = store.getHostProperty(prop_migration_port, host)
        xen_protocol = store.getHostProperty(prop_xen_protocol, host)
        use_keys = store.getHostProperty(prop_use_keys, host)
        address = store.getHostProperty(prop_address, host)

        ssh_port = int(self._first_not_none(ssh_port, default_ssh_port, 22))
        migration_port = int(self._first_not_none(migration_port,
                                                  default_migration_port,
                                                  8002))
        xen_protocol = self._first_not_none(xen_protocol,
                                            default_xen_protocol, "tcp")

        if use_keys is None:
            use_keys = False
        else:
            use_keys = self._parse_bool(use_keys)

        print("%s remote = %s %s %s key based auth = %s"
              % (host, remote, xen_protocol, ssh_port, use_keys))

        if username is None or username.strip() == "":
            username = 'root'

        return XenNode(hostname=host,
                       username = username,
                       isRemote = remote,
                       tcp_port = xen_port,
                       helper = creds_helper,
                       migration_port = migration_port,
                       protocol = xen_protocol, # tcp, ssl, or ssh_tunnel
                       ssh_port = ssh_port,
                       use_keys = use_keys,
                       address  = address)

    @staticmethod
    def _parse_bool(value):
        """Interpret a stored flag such as "True"/"False" as a bool.

        "true", "1", "yes" and "on" (case-insensitive, surrounding
        whitespace ignored) are True; anything else is False.
        """
        return str(value).strip().lower() in ("true", "1", "yes", "on")

    @staticmethod
    def _first_not_none(*values):
        """Return the first of *values* that is not None (None if all are)."""
        for value in values:
            if value is not None:
                return value
        return None

    def _save_groups(self):
        """Persist the current group list to the store."""
        self.store.saveGroups(self.group_list)

    def _save_node(self, node):
        """Persist the connection properties of *node* under its hostname."""
        props = { prop_xen_port : node.tcp_port,
                  prop_login : node.username,
                  prop_ssh_port : str(node.ssh_port),
                  prop_migration_port : str(node.migration_port),
                  prop_isRemote : str(node.isRemote),
                  prop_xen_protocol : node.protocol,
                  prop_use_keys : str(node.use_keys),
                  prop_address  : node.address}

        self.store.setHostProperties(props, node.hostname)

    def _create_default_groups(self):
        """Create and persist the built-in Desktops/QA Lab/Servers groups."""
        for name in ["Desktops", "QA Lab", "Servers"]:
            default_group = ServerGroup(name)
            self.group_list[default_group.name] = default_group
        self._save_groups()

    def _find_groups(self, node_name, group_map = None):
        """Return the names of all groups that contain *node_name*.

        group_map -- optional {group name: node names}; when omitted it is
                     built from the in-memory group list.
        """
        if group_map is None:
            group_map = {}
            for g in self.group_list:
                group_map[g] = self.group_list[g].getNodeNames()

        return [g for g in group_map if node_name in group_map[g]]

    def discoverNodes(self,ip_list):
        """Not implemented; does nothing."""
        pass

    def getNodeNames(self, group_name = None):
        """Return node names of a group, or of uncategorized nodes.

        Returns None when *group_name* is given but no such group exists.
        """
        if group_name is None:
            return self.node_list.keys()
        group = self.group_list.get(group_name)
        if group is not None:
            return group.getNodeNames()
        return None

    def getNodeList(self, group_name = None):
        """Return the nodes of a group, or the uncategorized node dict.

        Returns None when *group_name* is given but no such group exists.
        """
        if group_name is None:
            return self.node_list
        group = self.group_list.get(group_name)
        if group is not None:
            return group.getNodeList()
        return None

    def getNode(self,name, group_name = None):
        """Return the node called *name*.

        Without *group_name* the uncategorized nodes are searched and a
        missing name raises KeyError.  With an unknown *group_name* the
        result is None.
        """
        if group_name is None:
            return self.node_list[name]
        group = self.group_list.get(group_name)
        if group is not None:
            return group.getNode(name)
        return None

    def addNode(self,node, group_name = None):
        """Add *node* as uncategorized or to *group_name*, and persist it.

        Raises Exception if an uncategorized node with the same hostname
        already exists.  An unknown *group_name* is silently ignored.
        """
        if group_name is None:
            if self.node_list.get(node.hostname) is None:
                self._save_node(node)
                self.node_list[node.hostname] = node
            else:
                raise Exception("Server %s already exists" % node.hostname)
        else:
            group = self.group_list.get(group_name)
            if group is not None:
                group._addNode(node)
                self._save_node(node)
                self._save_groups()

    def removeNode(self,name,group_name = None):
        """Remove *name* from the uncategorized list or from *group_name*.

        Once the node belongs to no group and is not uncategorized, its
        host entry is also removed from the store.
        """
        if group_name is None:
            if self.node_list.get(name) is not None:
                del self.node_list[name]
        else:
            group = self.group_list.get(group_name)
            if group is not None:
                group._removeNode(name)
                self._save_groups()

        # remove the node, if not part of any other groups
        groups = self._find_groups(name)
        if len(groups) == 0 and name not in self.node_list:
            self.store.removeHost(name)

    def list_nodes(self):
        """Print uncategorized node names and group names (debug aid)."""
        print("## DUMP ==")
        for name in self.getNodeNames():
            print("Node name %s" % (name,))
        for g in self.group_list:
            print("group  %s" % (g,))
        print("## END DUMP ==")

    def cloneNode(self,source_node, dest):
        """Not implemented; does nothing."""
        pass

    def migrateDomains(self, source_node, vm_list, dest, live, force=False):
        """Migrate every domain in *vm_list* from *source_node* to *dest*.

        Unless *force* is set, source_node.migration_checks runs once for the
        whole list; any exception it raises propagates and nothing is
        migrated.  Per-VM failures are collected, and the remaining VMs are
        still attempted.  After the loop, a single Exception listing every
        failure is raised.
        """
        if len(vm_list) > 0 and not force:
            # NOTE(review): the returned error/warning lists are ignored;
            # presumably migration_checks raises on blocking problems.
            (err_list, warn_list) = source_node.migration_checks(vm_list,
                                                                 dest, live)

        ex_list = []
        for vm in vm_list:
            try:
                # checks already done above, so force the per-VM call
                self.migrateDomain(vm.name, source_node, dest, live,
                                   force = True)
            except Exception:
                # sys.exc_info() keeps this valid on old Python 2 releases.
                err = sys.exc_info()[1]
                ex_list.append("Error migrating " + vm.name + " : " + str(err))

        if len(ex_list) > 0:
            raise Exception("Errors in migrate all operations \n" +
                            "".join([m + "\n" for m in ex_list]))

    def migrateNode(self,source_node, dest, live, force = False):
        """ Migrate all vms on this node to a dest node."""
        vm_list = [vm for vm in source_node.get_doms() if not vm.isDom0()]
        self.migrateDomains(source_node, vm_list, dest, live, force)

    def cloneDomain(self,source_dom_name,
                    source_node,
                    dest_node=None):
        """Not implemented; does nothing."""
        pass

    def migrateDomain(self,source_dom_name,
                      source_node,
                      dest_node, live, force = False):
        """Migrate one domain (if running) and move its config file.

        Runs source_node.migration_checks first, unless *force* is set or the
        domain is not resident.  A socket timeout during the migration is
        reported and ignored.  Raises Exception if the domain is not found on
        *source_node*.
        """
        dom = source_node.get_dom(source_dom_name)
        if dom is None:
            raise Exception("Domain %s not found on the source server."
                            % source_dom_name)
        running = dom.is_resident

        if not force and running:
            # NOTE(review): returned lists ignored; see migrateDomains.
            (err_list, warn_list) = source_node.migration_checks([dom],
                                                                 dest_node,
                                                                 live)

        # No good mechanism for success or failure until we cut over to
        # task / XenAPI, so a timeout is not treated as fatal.
        try:
            if running:
                source_node.migrate_dom(source_dom_name, dest_node, live)
        except socket.timeout:
            print("ignoring timeout on migration ")

        # move config files if necessary.
        self.move_config_file(source_dom_name, source_node, dest_node)

    def move_config_file(self, dom_name, source_node, dest_node):
        """Copy a domain's config file to *dest_node* and re-associate it.

        If the file is not already present on *dest_node*, it is fetched
        from *source_node* through a local temp file, uploaded, and then
        removed from the source.  Nothing is done when the domain, its
        config, or the config's filename is missing.
        """
        dom = source_node.get_dom(dom_name)
        if not dom:
            return
        config = dom.get_config()
        if not config or config.filename is None:
            return

        if not dest_node.node_proxy.file_exists(config.filename):
            # Stage through a temp file on the client.  Only the path is
            # used, so close the descriptor right away.
            (t_handle, t_name) = tempfile.mkstemp(prefix=dom_name,
                                                  dir="/tmp")
            os.close(t_handle)
            try:
                source_node.node_proxy.get(config.filename, t_name)
                dest_node.node_proxy.put(t_name, config.filename)
                source_node.node_proxy.remove(config.filename)
            finally:
                os.remove(t_name)

        # now reassociate the config with the new node.
        dest_node.add_dom_config(config.filename)
        source_node.remove_dom_config(config.filename)

    # server pool related functions.
    def getGroupNames(self):
        """Return the names of all groups."""
        return self.group_list.keys()

    def getGroupList(self):
        """Return the {group name: ServerGroup} dict."""
        return self.group_list

    def getGroup(self,name):
        """Return the group called *name* (KeyError if absent)."""
        return self.group_list[name]

    def addGroup(self,grp):
        """Add and persist *grp*; raises Exception if the name is taken."""
        if grp.name in self.group_list:
            raise Exception("Group already exists.")
        self.group_list[grp.name] = grp
        self._save_groups()

    def removeGroup(self,name, deep=False):
        """Delete group *name* and persist the change.

        With *deep*, every member node is first removed from the group via
        removeNode, which also drops hosts left without any group.
        """
        if deep:
            # Snapshot the names: removeNode mutates the group's membership.
            for node_name in list(self.group_list[name].getNodeNames()):
                self.removeNode(node_name, name)
        del self.group_list[name]
        self._save_groups()




        

    
        
        
