#
# Copyright (C) 2005 Chris Halls <halls@debian.org>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""Unit test for apt_proxy.py"""

import copy, tempfile, os, shutil

from twisted.trial import unittest
from twisted.internet import reactor
from StringIO import StringIO

from apt_proxy.apt_proxy_conf import apConfig
from apt_proxy.apt_proxy import Factory
from apt_proxy.clients import HttpRequestClient
from apt_proxy.misc import log

# Baseline configuration: three backends (backend1-backend3), full debug
# output, port 9999, with cleanup_freq and max_versions switched off.  Some
# tests substitute other values for those two settings via str.replace().
config1="""
[DEFAULT]
debug=all:9
port=9999
address=
cleanup_freq=off
max_versions=off

[backend1]
backends = http://a.b.c/d

[backend2]
backends = http://d.e.f/g

[backend3]
backends = http://h.i.j/k
"""

# Replacement configuration loaded by ConfigChangeTest after config1:
# backend1 is removed, backend2 is unchanged, backend3 gets a new URI and
# backend4/backend5 are added.  port and address also differ from config1;
# ConfigChangeTest.testNotAllowedChanges expects the factory to keep the
# old values for those two settings.
config2="""
[DEFAULT]
debug=all:1
port=8888
address=1.2.3.4
cleanup_freq=off

# Backend 1 deleted

[backend2]
# no change
backends = http://d.e.f/g

[backend3]
# changed
backends = http://l.m.n/o

[backend4]
# new
backends = http://p.q.r/s

[backend5]
# another new
backends = http://t.u.v/w
"""

class apTestHelper(unittest.TestCase):
    """
    Base test case giving every test its own empty cache directory.

    Subclasses may replace ``default_config`` (on the class, or on the
    instance before calling this setUp).  A ``cache_dir=`` line pointing at
    the temporary directory is inserted right after the ``[DEFAULT]``
    header, and the resulting text is stored in ``self.config``.
    """
    # Config string to use
    default_config = "[DEFAULT]\ndebug=all:9 apt:0 memleak:0\ncleanup_freq=off\n"

    def setUp(self):
        """Create the temporary cache directory and build self.config."""
        tmp_dir = tempfile.mkdtemp('.aptproxy')
        self.cache_dir = tmp_dir
        section = '[DEFAULT]'
        self.config = self.default_config.replace(section, section + '\ncache_dir=' + tmp_dir)

    def tearDown(self):
        """Delete the cache directory and verify that it is really gone."""
        log.debug('Removing temporary directory: ' + self.cache_dir)
        shutil.rmtree(self.cache_dir)
        # Once the directory is removed, os.stat on it must fail.
        self.assertRaises(OSError, os.stat, self.cache_dir)

class FactoryTestHelper(apTestHelper):
    """
    Test base that builds a Factory from the helper's base configuration
    plus test-specific extra configuration, and stops it afterwards.
    """
    def setUp(self, config):
        """
        Build self.apConfig and self.factory.

        config -- extra configuration text, appended after a newline to the
                  base configuration produced by apTestHelper.setUp.
        """
        apTestHelper.setUp(self)
        combined = '\n'.join([self.config, config])
        self.apConfig = apConfig(StringIO(combined))
        self.factory = Factory(self.apConfig)
        self.factory.configurationChanged()

    def tearDown(self):
        """Stop the factory, then let apTestHelper remove the cache dir."""
        self.factory.stopFactory()
        del self.factory
        apTestHelper.tearDown(self)
        # The cache directory must no longer exist.
        self.assertRaises(OSError, os.stat, self.cache_dir)

class FactoryInitTest(apTestHelper):
    """Construction of a Factory from config1 and creation of its backends."""
    def setUp(self):
        # Use config1 (three backends) instead of the helper's default config.
        self.default_config = config1
        apTestHelper.setUp(self)
        self.c = apConfig(StringIO(self.config))
    def testFactoryInit(self):
        factory = Factory(self.c)
        self.assertEquals(factory.config, self.c)
        del factory
    def testFactoryBackendInit(self):
        factory = Factory(self.c)
        factory.configurationChanged()
        self.assertEquals(len(factory.backends),3)
        # factory.backends is a mapping (presumably a plain dict), whose key
        # order is arbitrary in Python 2, so compare the names sorted.
        self.assertEquals(sorted(factory.backends.keys()), ['backend1', 'backend2', 'backend3'])
        self.assertEquals(factory.backends['backend1'].uris[0].host, 'a.b.c')
        del factory

class StartFactoryTest(unittest.TestCase):
    """
    Behaviour of Factory.startFactory() with config1 in a private cache dir.
    """
    def setUp(self):
        self.cache_dir = tempfile.mkdtemp('.aptproxy')
        config = config1.replace('[DEFAULT]','[DEFAULT]\ncache_dir=' + self.cache_dir)
        self.c = apConfig(StringIO(config))
    def tearDown(self):
        shutil.rmtree(self.cache_dir)
    def testFactoryStart(self):
        "startFactory creates the recycler"
        factory = Factory(self.c)
        self.assertEquals(factory.recycler, None)
        # startFactory must actually be called (not just looked up); it is
        # what starts the recycler.
        factory.startFactory()
        try:
            self.assertNotEquals(factory.recycler, None)
        finally:
            factory.stopFactory()
    def testPeriodicOff(self):
        "Verify periodic callback is off"
        factory = Factory(self.c)
        factory.startFactory()
        try:
            # config1 sets cleanup_freq=off, so nothing should be scheduled.
            self.assertEquals(factory.periodicCallback, None)
        finally:
            factory.stopFactory()

class ConfigChangeTest(unittest.TestCase):
    """
    Switching a running factory's configuration from config1 to config2.
    """
    def setUp(self):
        self.tempdir = tempfile.mkdtemp('.aptproxy')
        configOld = config1.replace('[DEFAULT]','[DEFAULT]\ncache_dir=%s/old'%(self.tempdir))
        self.cOld = apConfig(StringIO(configOld))
        self.factory = Factory(self.cOld)
        self.factory.configurationChanged()
    def tearDown(self):
        del(self.factory)
        shutil.rmtree(self.tempdir)
    def loadNewConfig(self):
        """
        Give the factory a copy of config2 as its config, then call
        configurationChanged() with the previous (config1) configuration.
        """
        configNew = config2.replace('[DEFAULT]','[DEFAULT]\ncache_dir=%s/new'%(self.tempdir))
        self.cNew = apConfig(StringIO(configNew))
        self.factory.config = copy.copy(self.cNew)
        self.factory.configurationChanged(self.cOld)
    def testNotAllowedChanges(self):
        self.loadNewConfig()
        self.assertNotEquals(self.factory.config.port, self.cNew.port)
        self.assertEquals(self.factory.config.port, self.cOld.port)
        self.assertEquals(self.factory.config.address, self.cOld.address)
    def testGlobalChanges(self):
        self.loadNewConfig()
        self.assertEquals(self.factory.config.debug, 'all:1')
        self.assertEquals(self.factory.config.debug, self.cNew.debug)
    def testBackendCount(self):
        self.loadNewConfig()
        self.assertEquals(len(self.factory.backends),4)
    def testBackendChanges(self):
        self.assertEquals(self.factory.backends['backend3'].uris[0].host, 'h.i.j')
        self.loadNewConfig()
        # Key order of factory.backends is not guaranteed; compare sorted.
        self.assertEquals(sorted(self.factory.backends.keys()), ['backend2', 'backend3', 'backend4', 'backend5'])
        self.assertEquals(self.factory.backends['backend3'].uris[0].host, 'l.m.n')

class FactoryFnsTest(FactoryTestHelper):
    """
    Periodic-callback control and database dumping, using config1 with
    cleanup_freq switched on (1h).
    """
    def setUp(self):
        """
        Build the factory from config1 with hourly cleanup enabled.
        """
        hourly = config1.replace("cleanup_freq=off", "cleanup_freq=1h")
        FactoryTestHelper.setUp(self, hourly)

    def testPeriodicControl(self):
        "Start & stop periodic callback"
        factory = self.factory
        # setUp enabled cleanup_freq=1h, so a callback is expected already.
        self.assertNotEquals(factory.periodicCallback, None)
        steps = ((factory.stopPeriodic, False),
                 (factory.startPeriodic, True),
                 (factory.stopPeriodic, False))
        for action, expect_scheduled in steps:
            action()
            if expect_scheduled:
                self.assertNotEquals(factory.periodicCallback, None)
            else:
                self.assertEquals(factory.periodicCallback, None)
    def testPeriodic(self):
        "Run periodic cleaning"
        factory = self.factory
        factory.startFactory()   # brings up the recycler
        factory.stopPeriodic()   # cancel the scheduled callback...
        factory.periodic()       # ...and run one cleaning pass by hand
        # Running periodic() is expected to schedule the next callback.
        self.assertNotEquals(factory.periodicCallback, None)
        factory.stopPeriodic()   # cancel that follow-up callback
        self.assertEquals(factory.periodicCallback, None)


    def testDumpDbs(self):
        "Test that factory.dumpdbs() runs to completion"
        self.factory.dumpdbs()

class FactoryVersionsTest(FactoryTestHelper):
    """
    Bookkeeping done by Factory.file_served (access times and package
    versions), with max_versions=2.
    """
    def setUp(self):
        """
        Set up a factory from config1 with max_versions=2.
        """
        FactoryTestHelper.setUp(self, config1.replace("max_versions=off", "max_versions=2"))

    def testFirstFileServed(self):
        "Add non-.deb to databases"
        filename = 'debian/dists/stable/Release.gpg'
        path = os.sep + filename
        self.failIf(self.factory.access_times.has_key(path))
        self.factory.file_served(filename)
        self.failUnless(self.factory.access_times.has_key(path))
        # This is not a versioned file.
        # NOTE(review): packages appears to be keyed by package name (see
        # testDebServed1), so checking for the path here is weak -- confirm.
        self.failIf(self.factory.packages.has_key(path))

    def testDebServed1(self):
        "Add new .deb to databases"
        filename = 'debian/nonexistent_1.0.deb'
        path = os.sep + filename
        packagename = 'nonexistent'
        self.failIf(self.factory.access_times.has_key(path))
        self.failIf(self.factory.packages.has_key(packagename))
        self.factory.file_served(filename)
        self.failUnless(self.factory.access_times.has_key(path))
        # A .deb is a versioned file, so its package must now be tracked.
        self.failUnless(self.factory.packages.has_key(packagename))
        pkgs = self.factory.packages[packagename]
        self.assertEquals(len(pkgs), 1)

    def testDebServed2(self):
        "Add two .debs to databases"
        file1 = 'debian/nonexistent_1.0.deb'
        file2 = file1.replace('1.0', '1.1')
        packagename = 'nonexistent'
        self.factory.file_served(file1)
        self.factory.file_served(file2)
        self.failUnless(self.factory.packages.has_key(packagename))
        pkgs = self.factory.packages[packagename]
        self.assertEquals(len(pkgs), 2)

    def testDebServed3(self):
        "Test max_versions algorithm"
        versions = ['0.0.1', '0.0.2', '0.0.3']
        packagename = 'apt'
        backend_dir = os.path.join(self.cache_dir, 'backend1')
        os.mkdir(backend_dir)
        for ver in versions:
            package_filename = 'apt_' + ver + '_test.deb'
            # The fixture path is relative to the current working directory.
            shutil.copy2('../test_data/apt/' + package_filename,
                         os.path.join(backend_dir, package_filename))
            self.factory.file_served('backend1' + os.sep + package_filename)
        pkgs = self.factory.packages[packagename]
        # Three versions were served with max_versions=2: one must be gone.
        self.assertEquals(len(pkgs), 2)

# Extra configuration for BackendServerTest: one backend with a URI for
# each of http, ftp, rsync and file, and one backend whose http/ftp URIs
# carry user:password credentials.
backendServerConfig = """
[test_servers]
backends=http://server1/path1
         ftp://server2/path2
         rsync://server3/path3
         file://server4/path4
[test_usernames]
backends=http://myUser:thePassword@httpserver/httppath
         ftp://myFtpUser:theFtpPassword@ftpserver/ftppath
"""
class BackendServerTest(FactoryTestHelper):
    """
    Parsing of backend URIs from backendServerConfig: host, path, scheme,
    default port and embedded credentials.
    """
    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        FactoryTestHelper.setUp(self, backendServerConfig)
        self.backend = self.factory.getBackend('test_servers')
        self.backend2 = self.factory.getBackend('test_usernames')

    def _assertUriAttr(self, uris, attr, expected):
        """
        Assert that uris has exactly len(expected) entries and that
        getattr(uri, attr) equals the expected value at the same position.
        """
        # Check the count first; a bare zip would silently ignore missing
        # or extra URIs.
        self.assertEquals(len(uris), len(expected))
        for uri, value in zip(uris, expected):
            self.assertEquals(getattr(uri, attr), value)

    def testServerHosts(self):
        self._assertUriAttr(self.backend.uris, 'host',
                            ['server1','server2','server3','server4'])
        self._assertUriAttr(self.backend2.uris, 'host',
                            ['httpserver','ftpserver'])
    def testServerPaths(self):
        self._assertUriAttr(self.backend.uris, 'path',
                            ['/path1','/path2','/path3','/path4'])
        self._assertUriAttr(self.backend2.uris, 'path',
                            ['/httppath','/ftppath'])
    def testServerProtocols(self):
        self._assertUriAttr(self.backend.uris, 'scheme',
                            ['http','ftp','rsync','file'])
    def testServerDefaultPorts(self):
        self._assertUriAttr(self.backend.uris, 'port', [80,21,873,0])
    def testStr(self):
        "__str__ operator"
        for server in self.backend.uris:
            self.assertNotEquals(server.__str__(), None)
    def testNoUser(self):
        self.assertEquals(self.backend.uris[0].username,None)
    def testNoPassword(self):
        self.assertEquals(self.backend.uris[0].password,None)
    def testUser(self):
        self.assertEquals(self.backend2.uris[0].username,'myUser')
        self.assertEquals(self.backend2.uris[1].username,'myFtpUser')
    def testPassword(self):
        self.assertEquals(self.backend2.uris[0].password,'thePassword')
        self.assertEquals(self.backend2.uris[1].password,'theFtpPassword')
