# -*- coding: utf-8 -*-
#
# Author: Natalia B. Bidart <natalia.bidart@canonical.com>
# Author: Alejandro J. Cura <alecu@canonical.com>
#
# Copyright (C) 2009, 2011 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License version 3,
# as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Tests for the protocol client."""

import StringIO
import os
import sys
import time
import unittest
import uuid

from twisted.application import internet, service
from twisted.internet import defer
from twisted.internet.defer import Deferred, inlineCallbacks
from twisted.trial.unittest import TestCase as TwistedTestCase
from twisted.web import server, resource

from ubuntuone.storageprotocol import protocol_pb2, sharersp, delta, request
from ubuntuone.storageprotocol.client import (
    StorageClient, CreateUDF, ListVolumes, DeleteVolume, GetDelta, Unlink,
    Authenticate, MakeFile, MakeDir, PutContent, Move, BytesMessageProducer,
    oauth, TwistedTimestampChecker, tx_timestamp_checker,
)
from ubuntuone.storageprotocol import volumes
from tests import test_delta_info

# let's not get picky about attributes defined outside __init__ in tests
# pylint: disable=W0201
# it's ok to access internals in the test suite
# pylint: disable=W0212

# Fixed fixtures shared by the tests below.
# A path with a non-ASCII character, to exercise unicode handling.
PATH = u'~/Documents/pdfs/moño/'
NAME = u'UDF-me'
# Well-known ids so assertions can compare against them.
VOLUME = uuid.UUID(hex='12345678-1234-1234-1234-123456789abc')
SHARE = uuid.UUID(hex='33333333-1234-1234-1234-123456789abc')
NODE = uuid.UUID(hex='FEDCBA98-7654-3211-2345-6789ABCDEF12')
USER = u'Dude'
GENERATION = 999


class FakedError(Exception):
    """Exception raised by the stub that stands in for Request.error."""

    pass


def stub_function(*args, **kwargs):
    """Accept any arguments, do nothing, and give back None.

    Used as a drop-in for collaborators whose behaviour is irrelevant.
    """
    del args, kwargs  # deliberately ignored


def faked_error(message):
    """Stub to replace Request.error.

    Instead of failing the request, raise FakedError so tests can catch it
    with assertRaises.

    @param message: whatever the code under test passed to Request.error;
        it is attached to the exception so a failing test shows why the
        request errored.
    @raise FakedError: always.
    """
    raise FakedError(message)


def was_called(self, flag):
    """Return a stub that records its invocation by setting a flag.

    @param self: the object (usually a TestCase) that owns the flag.
    @param flag: name of an attribute on *self*. It must already exist
        (getattr raises AttributeError otherwise) and be false.
    @return: a callable accepting any arguments which, when called, sets
        ``getattr(self, flag)`` to True.
    @raise AssertionError: if the flag is already true when this helper
        is called.
    """
    # Raise explicitly instead of using ``assert`` so the precondition is
    # still enforced when Python runs with -O.
    if getattr(self, flag):
        raise AssertionError('flag %r is already set on %r' % (flag, self))

    def set_flag(*args, **kwargs):
        """Record the calling was made."""
        setattr(self, flag, True)
    return set_flag


def _new_message(msg_type):
    """Return a fresh protocol Message whose type field is *msg_type*."""
    message = protocol_pb2.Message()
    message.type = msg_type
    return message


def build_list_volumes():
    """Build a LIST_VOLUMES message.

    The type set is VOLUMES_INFO, which the tests feed to ListVolumes.
    """
    return _new_message(protocol_pb2.Message.VOLUMES_INFO)


def build_volume_created():
    """Build VOLUME_CREATED message."""
    return _new_message(protocol_pb2.Message.VOLUME_CREATED)


def build_volume_deleted():
    """Build VOLUME_DELETED message."""
    return _new_message(protocol_pb2.Message.VOLUME_DELETED)


def set_root_message(message):
    """Fill *message* in place as a ROOT volume whose node is NODE."""
    message.type = protocol_pb2.Volumes.ROOT
    root = message.root
    root.node = str(NODE)


def set_udf_message(message):
    """Fill *message* in place as a UDF volume.

    The UDF gets VOLUME as its id, NODE as its node and PATH as the
    suggested path.
    """
    message.type = protocol_pb2.Volumes.UDF
    udf = message.udf
    udf.volume = str(VOLUME)
    udf.node = str(NODE)
    udf.suggested_path = PATH


def set_share_message(message):
    """Fill *message* in place as a SHARE volume with fixed test data.

    The share id is VOLUME, the subtree is NODE, and both the other user's
    name and visible name are USER; the share is not accepted.
    """
    message.type = protocol_pb2.Volumes.SHARE
    share = message.share
    share.share_id = str(VOLUME)
    share.direction = 0
    share.subtree = str(NODE)
    share.share_name = u'test'
    share.other_username = USER
    share.other_visible_name = USER
    share.accepted = False
    share.access_level = 0


class MethodMock(object):
    """A callable used to overwrite methods and know if they were called.

    Any positional or keyword arguments are accepted and ignored, so it can
    replace methods with arbitrary signatures.

    @ivar called: boolean, true if the mock was called at least once
    @ivar call_count: int, the number of calls
    """

    def __init__(self):
        """Create the mock."""
        self.called = False
        self.call_count = 0

    def __call__(self, *args, **kwargs):
        """Update call stats; the arguments are ignored."""
        self.called = True
        self.call_count += 1


class DummyAttribute(object):
    """Catch-all stand-in for unrelated collaborators.

    Every attribute looked up on an instance is the no-op stub_function.
    """

    def __getattribute__(self, name):
        """Ignore *name* and hand back the shared no-op stub."""
        stub = stub_function
        return stub


class FakedProtocol(StorageClient):
    """StorageClient double that records outgoing messages.

    Messages are kept in ``self.messages`` instead of being written to a
    transport, and the transport is replaced by a DummyAttribute.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the real client, then swap in the fakes."""
        StorageClient.__init__(self, *args, **kwargs)
        self.messages = []
        self.transport = DummyAttribute()

    def sendMessage(self, message):
        """Store *message* in self.messages rather than sending it."""
        sent = self.messages
        sent.append(message)


class ClientTestCase(unittest.TestCase):
    """Tests for the StorageClient API and its server message handlers.

    The client-to-server tests check that each public method starts the
    matching request and returns a Deferred. The server-to-client tests
    feed hand-built messages to the ``handle_*`` methods and check what the
    registered callbacks receive.
    """

    def setUp(self):
        """Initialize testing client."""
        self.client = FakedProtocol()
        self.called = False
        self.volume = None

    def tearDown(self):
        """Clean up."""
        self.client = None

    def _assert_starts_request(self, request_class, method, **kwargs):
        """Check that calling *method* starts a *request_class* request.

        ``request_class.start`` is temporarily replaced by a stub that sets
        ``self.called``. The original is restored even if an assertion
        fails.

        @param request_class: the request class whose start() must be hit.
        @param method: the client method under test.
        @param kwargs: keyword arguments passed to *method*.
        """
        original = request_class.start
        request_class.start = was_called(self, 'called')

        try:
            result = method(**kwargs)
            self.assertTrue(self.called,
                            '%s.start() was called' % request_class.__name__)
            self.assertTrue(isinstance(result, Deferred))
        finally:
            request_class.start = original

    def test_init_maxpayloadsize(self):
        """Get the value from the constant at init time."""
        self.assertEqual(self.client.max_payload_size,
                         request.MAX_PAYLOAD_SIZE)

    # client to server
    def test_client_get_delta(self):
        """Get a delta."""
        self._assert_starts_request(GetDelta, self.client.get_delta,
                                    share_id=SHARE, from_generation=0)

    def test_client_get_delta_from_scratch(self):
        """Get a delta from scratch."""
        self._assert_starts_request(GetDelta, self.client.get_delta,
                                    share_id=SHARE, from_scratch=True)

    def test_client_get_delta_bad(self):
        """Require from_generation or from_scratch."""
        self.assertRaises(TypeError, self.client.get_delta,
                          share_id=SHARE, callback=1)

    def test_create_udf(self):
        """Test create_udf."""
        self._assert_starts_request(CreateUDF, self.client.create_udf,
                                    path=PATH, name=NAME)

    def test_list_volumes(self):
        """Test list_volumes."""
        self._assert_starts_request(ListVolumes, self.client.list_volumes)

    def test_delete_volume(self):
        """Test delete_volume."""
        self._assert_starts_request(DeleteVolume, self.client.delete_volume,
                                    volume_id=VOLUME)

    def test_set_volume_deleted_callback(self):
        """Test callback setting."""
        a_callback = lambda x: None
        self.client.set_volume_deleted_callback(a_callback)
        self.assertTrue(self.client._volume_deleted_callback is a_callback)

    def test_callback_must_be_callable(self):
        """Test set callback parameters."""
        self.assertRaises(TypeError, self.client.set_volume_created_callback,
                          'hello')

        self.assertRaises(TypeError, self.client.set_volume_deleted_callback,
                          'world')

        self.assertRaises(TypeError,
                          self.client.set_volume_new_generation_callback, 'fu')

    def test_set_volume_created_callback(self):
        """Test callback setting."""
        a_callback = lambda y, z: None
        self.client.set_volume_created_callback(a_callback)
        self.assertTrue(self.client._volume_created_callback is a_callback)

    def test_set_volume_new_generation_callback(self):
        """Test callback setting."""
        cback = lambda y, z: None
        self.client.set_volume_new_generation_callback(cback)
        self.assertTrue(self.client._volume_new_generation_callback is cback)

    # share notification callbacks
    def test_share_change_callback(self):
        """Test share_change callback usage."""
        self.assertRaises(TypeError, self.client.set_share_change_callback,
                          'hello')
        # create a response and message
        share_resp = sharersp.NotifyShareHolder.from_params(
            uuid.uuid4(), uuid.uuid4(), 'sname', 'byu', 'tou', 'View')
        proto_msg = protocol_pb2.Message()
        proto_msg.type = protocol_pb2.Message.NOTIFY_SHARE
        share_resp.dump_to_msg(proto_msg.notify_share)

        # wire up a call back and make sure it's correct
        self.share_notif = None
        a_callback = lambda notif: setattr(self, 'share_notif', notif)
        self.client.set_share_change_callback(a_callback)
        self.assertTrue(self.client._share_change_callback is a_callback)
        self.client.handle_NOTIFY_SHARE(proto_msg)
        self.assertEqual(self.share_notif.share_id, share_resp.share_id)

    def test_share_delete_callback(self):
        """Test share_delete callback usage."""
        self.assertRaises(TypeError, self.client.set_share_delete_callback,
                          'hello')

        share_id = uuid.uuid4()
        proto_msg = protocol_pb2.Message()
        proto_msg.type = protocol_pb2.Message.SHARE_DELETED
        proto_msg.share_deleted.share_id = str(share_id)

        # wire up a call back and make sure it's correct
        self.deleted_share_id = None
        a_callback = lambda notif: setattr(self, 'deleted_share_id', notif)
        self.client.set_share_delete_callback(a_callback)
        self.assertTrue(self.client._share_delete_callback is a_callback)
        self.client.handle_SHARE_DELETED(proto_msg)
        self.assertEqual(self.deleted_share_id, share_id)

    def test_share_answer_callback(self):
        """Test share_answer callback usage."""
        self.assertRaises(TypeError, self.client.set_share_answer_callback,
                          'hello')

        share_id = uuid.uuid4()
        proto_msg = protocol_pb2.Message()
        proto_msg.type = protocol_pb2.Message.SHARE_ACCEPTED
        proto_msg.share_accepted.share_id = str(share_id)
        proto_msg.share_accepted.answer = protocol_pb2.ShareAccepted.YES

        # wire up a call back and make sure it's correct
        self.answer = None
        a_callback = lambda s, a: setattr(self, 'answer', (s, a))
        self.client.set_share_answer_callback(a_callback)
        self.assertTrue(self.client._share_answer_callback is a_callback)
        self.client.handle_SHARE_ACCEPTED(proto_msg)
        self.assertEqual(self.answer[0], share_id)
        self.assertEqual(self.answer[1], "Yes")

    def test_handle_volume_new_generation_uuid(self):
        """Test handle_VOLUME_NEW_GENERATION with an uuid id."""
        # set the callback to record the info
        called = []
        f = lambda *a: called.append(a)
        self.client.set_volume_new_generation_callback(f)

        # create the message
        message = protocol_pb2.Message()
        message.type = protocol_pb2.Message.VOLUME_NEW_GENERATION
        volume_id = uuid.uuid4()
        message.volume_new_generation.volume = str(volume_id)
        message.volume_new_generation.generation = 77

        # send the message, and assert the callback is called with good info
        self.client.handle_VOLUME_NEW_GENERATION(message)
        self.assertEqual(called[0], (volume_id, 77))

    def test_handle_volume_new_generation_root(self):
        """Test handle_VOLUME_NEW_GENERATION for ROOT."""
        # set the callback to record the info
        called = []
        f = lambda *a: called.append(a)
        self.client.set_volume_new_generation_callback(f)

        # create the message
        message = protocol_pb2.Message()
        message.type = protocol_pb2.Message.VOLUME_NEW_GENERATION
        message.volume_new_generation.volume = request.ROOT
        message.volume_new_generation.generation = 77

        # send the message, and assert the callback is called with good info
        self.client.handle_VOLUME_NEW_GENERATION(message)
        self.assertEqual(called[0], (request.ROOT, 77))

    # server to client
    def test_handle_volume_created(self):
        """Test handle_VOLUME_CREATED."""
        a_callback = was_called(self, 'called')
        self.client.set_volume_created_callback(a_callback)

        message = build_volume_created()
        set_root_message(message.volume_created)
        self.client.handle_VOLUME_CREATED(message)

        self.assertTrue(self.called)

    def test_handle_root_created_passes_a_root(self):
        """Test handle_VOLUME_CREATED parameter passing."""
        self.volume = None
        a_callback = lambda vol: setattr(self, 'volume', vol)
        self.client.set_volume_created_callback(a_callback)

        message = build_volume_created()
        set_root_message(message.volume_created)
        root = volumes.RootVolume.from_msg(message.volume_created.root)

        self.client.handle_VOLUME_CREATED(message)
        self.assertEqual(root, self.volume)

    def test_handle_udf_created_passes_a_udf(self):
        """Test handle_VOLUME_CREATED parameter passing."""
        self.volume = None
        a_callback = lambda vol: setattr(self, 'volume', vol)
        self.client.set_volume_created_callback(a_callback)

        message = build_volume_created()
        set_udf_message(message.volume_created)
        udf = volumes.UDFVolume.from_msg(message.volume_created.udf)

        self.client.handle_VOLUME_CREATED(message)
        self.assertEqual(udf, self.volume)

    def test_handle_share_created_passes_a_share(self):
        """Test handle_VOLUME_CREATED parameter passing."""
        self.volume = None
        a_callback = lambda vol: setattr(self, 'volume', vol)
        self.client.set_volume_created_callback(a_callback)

        message = build_volume_created()
        set_share_message(message.volume_created)
        share = volumes.ShareVolume.from_msg(message.volume_created.share)

        self.client.handle_VOLUME_CREATED(message)
        self.assertEqual(share, self.volume)

    def test_handle_volume_created_if_volume_is_buggy(self):
        """Test handle_VOLUME_CREATED if the volume is buggy."""
        message = build_volume_created()
        message.volume_created.type = -1  # invalid type!
        self.client.set_volume_created_callback(lambda vol: None)
        self.assertRaises(TypeError, self.client.handle_VOLUME_CREATED,
                          message)

    def test_handle_volume_created_if_callback_is_none(self):
        """Test handle_VOLUME_CREATED if callback is none."""
        message = build_volume_created()
        self.client.handle_VOLUME_CREATED(message)

    def test_handle_volume_deleted(self):
        """Test handle_VOLUME_DELETED."""
        a_callback = was_called(self, 'called')
        self.client.set_volume_deleted_callback(a_callback)

        message = build_volume_deleted()
        message.volume_deleted.volume = str(VOLUME)
        self.client.handle_VOLUME_DELETED(message)

        self.assertTrue(self.called)

    def test_handle_volume_deleted_passes_the_id(self):
        """Test handle_VOLUME_DELETED."""
        self.volume = None
        a_callback = lambda vol_id: setattr(self, 'volume', vol_id)
        self.client.set_volume_deleted_callback(a_callback)

        message = build_volume_deleted()
        message.volume_deleted.volume = str(VOLUME)
        self.client.handle_VOLUME_DELETED(message)

        self.assertEqual(VOLUME, self.volume)

    def test_handle_volume_deleted_if_none(self):
        """Test handle_VOLUME_DELETED if callback is none."""
        message = build_volume_deleted()
        self.client.handle_VOLUME_DELETED(message)


class RequestTestCase(TwistedTestCase):
    """Base class for request tests.

    Before each test, ``request_class`` is replaced (on the instance) by a
    fresh subclass of itself with the same name. Presumably this is so
    that attributes a test patches on the class do not leak into other
    tests.
    """

    request_class = request.Request

    @defer.inlineCallbacks
    def setUp(self):
        """Swap request_class for a per-test subclass of itself."""
        yield super(RequestTestCase, self).setUp()
        import types
        base = self.request_class
        self.request_class = types.ClassType(base.__name__, (base,), {})


class CreateUDFTestCase(RequestTestCase):
    """Exercise the CreateUDF request against a faked protocol."""

    request_class = CreateUDF

    @defer.inlineCallbacks
    def setUp(self):
        """Build a CreateUDF with error() and done() stubbed out."""
        yield super(CreateUDFTestCase, self).setUp()
        self.protocol = FakedProtocol()
        self.done_called = False
        req = self.request_class(self.protocol, path=PATH, name=NAME)
        req.error = faked_error
        req.done = was_called(self, 'done_called')
        self.request = req

    def test_init(self):
        """A new request keeps path and name and has no ids yet."""
        req = self.request
        self.assertEqual(PATH, req.path)
        self.assertEqual(NAME, req.name)
        self.assertTrue(req.volume_id is None)
        self.assertTrue(req.node_id is None)

    def test_start(self):
        """start() sends exactly one CREATE_UDF message."""
        req = self.request
        req.start()

        sent = req.protocol.messages
        self.assertEqual(1, len(sent))
        (msg,) = sent
        self.assertEqual(protocol_pb2.Message.CREATE_UDF, msg.type)
        self.assertEqual(req.path, msg.create_udf.path)
        self.assertEqual(req.name, msg.create_udf.name)

    def test_process_message_error(self):
        """An unexpected (empty) message goes through error()."""
        empty = protocol_pb2.Message()
        self.assertRaises(FakedError, self.request.processMessage, empty)

    def test_process_message_volume_created(self):
        """A VOLUME_CREATED reply stores the ids and calls done()."""
        message = protocol_pb2.Message()
        message.type = protocol_pb2.Message.VOLUME_CREATED
        created = message.volume_created
        created.type = protocol_pb2.Volumes.UDF
        created.udf.volume = str(VOLUME)
        created.udf.node = str(NODE)
        self.request.processMessage(message)

        self.assertEqual(str(VOLUME), self.request.volume_id, 'volume set')
        self.assertEqual(str(NODE), self.request.node_id, 'node set')
        self.assertTrue(self.done_called, 'done() was called')


class ListVolumesTestCase(RequestTestCase):
    """Exercise the ListVolumes request against a faked protocol."""

    request_class = ListVolumes

    @defer.inlineCallbacks
    def setUp(self):
        """Build a ListVolumes with error() and done() stubbed out."""
        yield super(ListVolumesTestCase, self).setUp()
        self.protocol = FakedProtocol()
        self.done_called = False
        req = self.request_class(self.protocol)
        req.error = faked_error
        req.done = was_called(self, 'done_called')
        self.request = req

    def test_init(self):
        """A new request has no volumes."""
        self.assertEqual([], self.request.volumes)

    def test_start(self):
        """start() sends exactly one LIST_VOLUMES message."""
        req = self.request
        req.start()

        sent = req.protocol.messages
        self.assertEqual(1, len(sent))
        (msg,) = sent
        self.assertEqual(protocol_pb2.Message.LIST_VOLUMES, msg.type)

    def test_process_message_error(self):
        """An unexpected (empty) message goes through error()."""
        empty = protocol_pb2.Message()
        self.assertRaises(FakedError, self.request.processMessage, empty)

    def test_process_message_error_when_incorrect_volume(self):
        """A volume of unknown type goes through error()."""
        bogus = build_list_volumes()
        bogus.list_volumes.type = -1
        self.assertRaises(FakedError, self.request.processMessage, bogus)

    def test_process_message_volume_created(self):
        """Volumes are collected in order and VOLUMES_END calls done()."""
        kinds = (
            (set_udf_message, volumes.UDFVolume, 'udf'),
            (set_share_message, volumes.ShareVolume, 'share'),
            (set_root_message, volumes.RootVolume, 'root'),
        )
        expected = []
        for fill, volume_class, field in kinds:
            message = build_list_volumes()
            fill(message.list_volumes)
            expected.append(
                volume_class.from_msg(getattr(message.list_volumes, field)))
            self.request.processMessage(message)

        self.assertEqual(3, len(self.request.volumes),
                         '3 volumes stored')
        self.assertEqual(expected, self.request.volumes,
                         'volumes stored')

        end = protocol_pb2.Message()
        end.type = protocol_pb2.Message.VOLUMES_END
        self.request.processMessage(end)

        self.assertTrue(self.done_called, 'done() was called')

    def test_start_cleanups_volumes(self):
        """A second start() drops volumes collected by the previous run."""
        req = self.request
        req.start()

        volume_msg = build_list_volumes()
        set_udf_message(volume_msg.list_volumes)
        req.processMessage(volume_msg)

        end = protocol_pb2.Message()
        end.type = protocol_pb2.Message.VOLUMES_END
        req.processMessage(end)

        req.start()
        self.assertEqual([], req.volumes)


class DeleteVolumeTestCase(RequestTestCase):
    """Exercise the DeleteVolume request against a faked protocol."""

    request_class = DeleteVolume

    @defer.inlineCallbacks
    def setUp(self):
        """Build a DeleteVolume with error() and done() stubbed out."""
        yield super(DeleteVolumeTestCase, self).setUp()
        self.protocol = FakedProtocol()
        self.done_called = False
        req = self.request_class(self.protocol, volume_id=VOLUME)
        req.error = faked_error
        req.done = was_called(self, 'done_called')
        self.request = req

    def test_init(self):
        """The volume id is stored as a string."""
        self.assertEqual(str(VOLUME), self.request.volume_id)

    def test_start(self):
        """start() sends exactly one DELETE_VOLUME message."""
        req = self.request
        req.start()

        sent = req.protocol.messages
        self.assertEqual(1, len(sent))
        (msg,) = sent
        self.assertEqual(protocol_pb2.Message.DELETE_VOLUME, msg.type)
        self.assertEqual(req.volume_id, msg.delete_volume.volume)

    def test_process_message_error(self):
        """An unexpected (empty) message goes through error()."""
        empty = protocol_pb2.Message()
        self.assertRaises(FakedError, self.request.processMessage, empty)

    def test_process_message_ok(self):
        """An OK reply calls done()."""
        ok = protocol_pb2.Message()
        ok.type = protocol_pb2.Message.OK
        self.request.processMessage(ok)

        self.assertTrue(self.done_called, 'done() was called')


class GetDeltaTestCase(RequestTestCase):
    """Test cases for GetDelta op.

    Every request is built from ``self.request_class``, the per-test
    subclass created by RequestTestCase.setUp, and gets the same
    error()/done() stubs as the one built in setUp.
    """

    request_class = GetDelta

    @defer.inlineCallbacks
    def setUp(self):
        """Initialize testing protocol."""
        yield super(GetDeltaTestCase, self).setUp()
        self.protocol = FakedProtocol()
        self.request = self.request_class(self.protocol, SHARE, 0)
        self.request.error = faked_error
        self.done_called = False
        self.request.done = was_called(self, 'done_called')

    def test_init(self):
        """Test request creation."""
        self.assertEqual(str(SHARE), self.request.share_id)

    def test_start(self):
        """Test request start."""
        self.request.start()

        self.assertEqual(1, len(self.request.protocol.messages))
        actual_msg, = self.request.protocol.messages
        self.assertEqual(protocol_pb2.Message.GET_DELTA, actual_msg.type)
        self.assertEqual(self.request.share_id,
                         actual_msg.get_delta.share)

    def test_process_message_error(self):
        """Test request processMessage on error."""
        message = protocol_pb2.Message()
        self.assertRaises(FakedError, self.request.processMessage, message)

    def test_process_message_ok(self):
        """Test request processMessage on success (DELTA_END)."""
        message = protocol_pb2.Message()
        message.type = protocol_pb2.Message.DELTA_END
        message.delta_end.generation = 100
        message.delta_end.full = True
        message.delta_end.free_bytes = 200
        self.request.processMessage(message)

        self.assertTrue(self.done_called, 'done() was called')
        self.assertEqual(self.request.end_generation,
                         message.delta_end.generation)
        self.assertEqual(self.request.full, message.delta_end.full)
        self.assertEqual(self.request.free_bytes,
                         message.delta_end.free_bytes)

    def test_process_message_content(self):
        """Test request processMessage for content."""
        message = test_delta_info.get_message()
        self.request.processMessage(message)
        self.assertTrue(delta.from_message(message) in self.request.response)

    def test_process_message_content_twice(self):
        """Each content message is appended to the response."""
        message = test_delta_info.get_message()
        self.request.processMessage(message)
        message = test_delta_info.get_message()
        self.request.processMessage(message)
        self.assertEqual(len(self.request.response), 2)

    def test_process_message_content_callback(self):
        """Test request processMessage for content w/callback."""
        response = []
        # Use the per-test subclass, like every other test in this class.
        self.request = self.request_class(self.protocol, SHARE, 0,
                                          callback=response.append)
        message = test_delta_info.get_message()
        self.request.processMessage(message)
        self.assertTrue(delta.from_message(message) in response)

    def test_from_scratch_flag(self):
        """Test from scratch flag."""
        self.request = self.request_class(self.protocol, SHARE, 0,
                                          from_scratch=True)
        # Same stubs as the request built in setUp.
        self.request.error = faked_error
        self.request.done = was_called(self, 'done_called')
        self.request.start()

        self.assertEqual(1, len(self.request.protocol.messages))
        actual_msg, = self.request.protocol.messages
        self.assertEqual(protocol_pb2.Message.GET_DELTA, actual_msg.type)
        self.assertEqual(self.request.share_id,
                         actual_msg.get_delta.share)


class TestAuth(RequestTestCase):
    """Tests the authentication request."""

    request_class = Authenticate

    def test_session_id(self):
        """Test that the request has the session id attribute."""
        session_id = "opaque_session_id"
        req = self.request_class(FakedProtocol(), None)
        req.done = MethodMock()
        message = protocol_pb2.Message()
        message.type = protocol_pb2.Message.AUTH_AUTHENTICATED
        message.session_id = session_id
        req.processMessage(message)
        self.assertTrue(req.done.called)
        self.assertEqual(req.session_id, session_id)

    def test_with_metadata(self):
        """Test with optional metadata."""
        metadata = {'version': '0.1', 'platform': sys.platform}
        protocol = FakedProtocol()
        protocol.dummy_authenticate('my_token', metadata=dict(metadata))
        msgs = protocol.messages
        # Check the exact count: assertTrue(x, 1) would only check that x
        # is truthy, treating 1 as the failure message.
        self.assertEqual(len(msgs), 1)
        msg = msgs[0]
        self.assertEqual(len(msg.auth_parameters), 1)
        self.assertEqual(len(msg.metadata), 2)
        for md in msg.metadata:
            self.assertTrue(md.key in metadata)
            self.assertEqual(md.value, metadata[md.key])

    def test_without_metadata(self):
        """Test without optional metadata."""
        protocol = FakedProtocol()
        protocol.dummy_authenticate('my_token')
        msgs = protocol.messages
        self.assertEqual(len(msgs), 1)
        msg = msgs[0]
        self.assertEqual(len(msg.auth_parameters), 1)
        self.assertEqual(len(msg.metadata), 0)

    def test_oauth_authenticate_uses_server_timestamp(self):
        """The oauth authentication uses the server timestamp."""
        fromcandt_call = []

        fake_token = oauth.OAuthToken('token', 'token_secret')
        fake_consumer = oauth.OAuthConsumer('consumer_key', 'consumer_secret')

        fake_timestamp = object()
        timestamp_d = Deferred()
        self.patch(tx_timestamp_checker, "get_faithful_time",
                   lambda: timestamp_d)
        original_fromcandt = oauth.OAuthRequest.from_consumer_and_token

        @staticmethod
        def fake_from_consumer_and_token(**kwargs):
            """A fake from_consumer_and_token."""
            fromcandt_call.append(kwargs)
            return original_fromcandt(**kwargs)

        self.patch(oauth.OAuthRequest, "from_consumer_and_token",
                   fake_from_consumer_and_token)
        protocol = FakedProtocol()
        protocol.oauth_authenticate(fake_consumer, fake_token)
        # Nothing is signed until the server timestamp is available.
        self.assertEqual(len(fromcandt_call), 0)
        timestamp_d.callback(fake_timestamp)
        parameters = fromcandt_call[0]["parameters"]
        self.assertEqual(parameters["oauth_timestamp"], fake_timestamp)


class RootResource(resource.Resource):
    """A leaf resource that counts HEAD requests and records their headers.

    @ivar count: number of HEAD requests rendered so far.
    @ivar request_headers: the requestHeaders of each HEAD request, in order.
    """

    isLeaf = True

    def __init__(self, *args, **kwargs):
        """Initialize this fake instance.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # Run the base initializer so the resource is fully set up (e.g.
        # its children mapping) even though it is a leaf.
        resource.Resource.__init__(self)
        self.count = 0
        self.request_headers = []

    def render_HEAD(self, request):
        """Record the request headers, bump the counter, send no body."""
        self.request_headers.append(request.requestHeaders)
        self.count += 1
        return ""


class MockWebServer(object):
    """A mock webserver for testing, serving a RootResource."""

    def __init__(self):
        """Start up this instance on an OS-assigned TCP port."""
        self.root = RootResource()
        site = server.Site(self.root)
        application = service.Application('web')
        self.service_collection = service.IServiceCollection(application)
        # Port 0 lets the OS choose a free port; get_url() reads it back.
        self.tcpserver = internet.TCPServer(0, site)
        self.tcpserver.setServiceParent(self.service_collection)
        self.service_collection.startService()

    def get_url(self):
        """Build the url for this mock server."""
        port_num = self.tcpserver._port.getHost().port
        return "http://localhost:%d/" % port_num

    def stop(self):
        """Shut it down.

        @return: whatever stopService() returns (a Deferred for a service
            collection). Returning it lets a trial cleanup wait until the
            listening port is really closed instead of leaving the reactor
            dirty.
        """
        return self.service_collection.stopService()


class TimestampCheckerTestCase(TwistedTestCase):
    """Tests for the timestamp checker."""

    @inlineCallbacks
    def setUp(self):
        """Start a mock web server and point the checker at it."""
        yield super(TimestampCheckerTestCase, self).setUp()
        self.ws = MockWebServer()
        self.addCleanup(self.ws.stop)
        # Every checker built in these tests queries the local mock server.
        self.patch(TwistedTimestampChecker, "SERVER_URL", self.ws.get_url())

    @inlineCallbacks
    def test_returned_value_is_int(self):
        """The returned value is an integer."""
        checker = TwistedTimestampChecker()
        t = yield checker.get_faithful_time()
        self.assertEqual(type(t), int)

    @inlineCallbacks
    def test_first_call_does_head(self):
        """The first call gets the clock from our web."""
        checker = TwistedTimestampChecker()
        yield checker.get_faithful_time()
        self.assertEqual(self.ws.root.count, 1)

    @inlineCallbacks
    def test_second_call_is_cached(self):
        """For the second call, the time is cached."""
        checker = TwistedTimestampChecker()
        yield checker.get_faithful_time()
        yield checker.get_faithful_time()
        self.assertEqual(self.ws.root.count, 1)

    @inlineCallbacks
    def test_after_timeout_cache_expires(self):
        """After some time, the cache expires."""
        fake_timestamp = 1
        # The lambda reads fake_timestamp each time it is called, so
        # advancing the local below also advances the patched clock.
        self.patch(time, "time", lambda: fake_timestamp)
        checker = TwistedTimestampChecker()
        yield checker.get_faithful_time()
        fake_timestamp += TwistedTimestampChecker.CHECKING_INTERVAL
        yield checker.get_faithful_time()
        self.assertEqual(self.ws.root.count, 2)

    @inlineCallbacks
    def test_server_error_means_skew_not_updated(self):
        """When server can't be reached, the skew is not updated."""
        fake_timestamp = 1
        self.patch(time, "time", lambda: fake_timestamp)
        checker = TwistedTimestampChecker()

        def failing_get_server_time():
            """Simulate an unreachable server with an already-failed Deferred."""
            return defer.fail(FakedError())

        self.patch(checker, "get_server_time", failing_get_server_time)
        yield checker.get_faithful_time()
        self.assertEqual(checker.skew, 0)
        expected_next_check = (fake_timestamp +
                               TwistedTimestampChecker.ERROR_INTERVAL)
        self.assertEqual(checker.next_check, expected_next_check)

    @inlineCallbacks
    def test_server_date_sends_nocache_headers(self):
        """Getting the server date sends the no-cache headers."""
        checker = TwistedTimestampChecker()
        yield checker.get_server_date_header(self.ws.get_url())
        # A real assertion rather than a bare ``assert``: it is not stripped
        # under -O, and trial reports the mismatching values.
        self.assertEqual(len(self.ws.root.request_headers), 1)
        headers = self.ws.root.request_headers[0]
        result = headers.getRawHeaders("Cache-Control")
        self.assertEqual(result, ["no-cache"])


class TestGenerationInRequests(RequestTestCase):
    """Check that volume-changing requests record new_generation.

    This base case exercises MakeFile; subclasses swap ``request_class``
    and override ``build_request``/``build_message`` for other requests.
    """

    request_class = MakeFile

    def build_request(self):
        """Creates the request object."""
        return self.request_class(FakedProtocol(), None, None, "name")

    def build_message(self):
        """Creates the ending message for the request."""
        message = protocol_pb2.Message()
        message.type = protocol_pb2.Message.NEW_FILE
        message.new_generation = GENERATION
        return message

    def test_make(self):
        """Test the request for new_generation."""
        req = self.build_request()
        # Replace done() so the test can see that the final message
        # completed the request.
        req.done = MethodMock()
        message = self.build_message()
        req.processMessage(message)
        # assertTrue instead of the deprecated ``assert_`` alias.
        self.assertTrue(req.done.called)
        self.assertEqual(req.new_generation, GENERATION)


class TestGenerationInRequestsMakeDir(TestGenerationInRequests):
    """Check that MakeDir records new_generation from its NEW_DIR reply."""

    request_class = MakeDir

    def build_message(self):
        """Build the NEW_DIR reply that finishes a MakeDir request."""
        reply = protocol_pb2.Message()
        reply.new_generation = GENERATION
        reply.type = protocol_pb2.Message.NEW_DIR
        return reply


class TestGenerationInRequestsPutContent(TestGenerationInRequests):
    """Check that PutContent records new_generation from its OK reply."""

    request_class = PutContent

    def build_request(self):
        """Build the request with every argument after the protocol unset."""
        # The eight remaining constructor arguments are irrelevant here.
        unused_args = (None,) * 8
        return self.request_class(FakedProtocol(), *unused_args)

    def build_message(self):
        """Build the OK reply that finishes the request."""
        reply = protocol_pb2.Message()
        reply.new_generation = GENERATION
        reply.type = protocol_pb2.Message.OK
        return reply


class TestGenerationInRequestsUnlink(TestGenerationInRequestsPutContent):
    """Check new_generation handling for Unlink requests.

    The OK reply comes from the inherited ``build_message``.
    """

    request_class = Unlink

    def build_request(self):
        """Build an Unlink request with both non-protocol arguments unset."""
        protocol = FakedProtocol()
        return self.request_class(protocol, None, None)


class TestGenerationInRequestsMove(TestGenerationInRequestsPutContent):
    """Check new_generation handling for Move requests.

    The OK reply comes from the inherited ``build_message``.
    """

    request_class = Move

    def build_request(self):
        """Build a Move request with all four non-protocol arguments unset."""
        unused_args = (None,) * 4
        return self.request_class(FakedProtocol(), *unused_args)


class PutContentTestCase(RequestTestCase):
    """Test cases for PutContent op."""

    request_class = PutContent

    def test_max_payload_size(self):
        """Get the value from the protocol."""
        self.protocol = FakedProtocol()
        # Guard: make sure 12345 is not already the default, or the test
        # would pass trivially.
        assert 12345 != self.protocol.max_payload_size
        self.protocol.max_payload_size = 12345
        # Use request_class like the other tests in this case.
        pc = self.request_class(self.protocol, None, None, None, None,
                                None, None, None, None)
        self.assertEqual(pc.max_payload_size, 12345)

    def test_bytesmessageproducer_maxpayloadsize(self):
        """The producer uses the payload size from the request."""
        # set up the PutContent
        pc = self.request_class(FakedProtocol(), None, None, None, None,
                                None, None, None, None)
        # Guard: make sure the override below actually changes something.
        assert 12345 != pc.max_payload_size
        pc.max_payload_size = 12345

        # set up the producer with more data than one payload can hold
        fake_file = StringIO.StringIO(os.urandom(100000))
        producer = BytesMessageProducer(pc, fake_file, 0)
        producer.producing = True

        # set up a function to check and go!
        d = Deferred()

        def check(message):
            """Check the first chunk's size and resolve ``d`` either way."""
            # Stop producing before asserting so a failure does not leave
            # the producer running.
            producer.producing = False
            try:
                self.assertEqual(len(message.bytes.bytes), 12345)
            except Exception:
                # Report the failure through the Deferred the test returns,
                # instead of raising into the producer (which could leave
                # ``d`` unfired and the test waiting for its timeout).
                d.errback()
            else:
                d.callback(True)

        pc.sendMessage = check
        producer.go()
        return d


# Allow running this module directly as a script with the stdlib runner.
if __name__ == '__main__':
    unittest.main()
