"""
Utilities for PyTables' test suites
===================================

:Author:   Ivan Vilata i Balaguer
:Contact:  ivilata@carabos.com
:Created:  2005-05-24
:License:  BSD
:Revision: $Id$
"""

import unittest
import tempfile
import os

import numarray
import numarray.strings
import numarray.records

import tables
from test_all import verbose



def verbosePrint(string):
    """
    Print out the `string` if verbose output is enabled.

    Verbosity is controlled by the module-level ``verbose`` flag imported
    from ``test_all``.  Nothing is printed and nothing is returned when
    the flag is false.
    """
    if verbose:
        # A single parenthesised argument prints exactly the same thing
        # under the Python 2 ``print`` statement, and is also valid
        # syntax where ``print`` is a function.
        print(string)



def areArraysEqual(arr1, arr2):
    """
    Are both `arr1` and `arr2` equal arrays?

    Arguments can be regular Numarray arrays, ``CharArray`` arrays or
    record arrays and their descendants (i.e. nested record arrays and
    nested records).  They are checked for type and value equality.

    The types of both arguments are considered compatible when they are
    the same class or when one is a subclass of the other.  After that:

    * record arrays must have the same field names, and every pair of
      fields is compared recursively with this same function;
    * numeric arrays must have the same shape, the same ``type()`` and
      the same elements;
    * character arrays must have the same shape, the same ``_type`` and
      equal items along the first axis.

    Returns a true value if the arrays are equal and ``False`` otherwise,
    including when the arguments are of none of the supported kinds.
    """

    t1 = type(arr1)
    t2 = type(arr2)

    # Unrelated classes can never be equal; related ones (same class or
    # parent/child) go on to be compared by content.
    if not ((t1 is t2) or issubclass(t1, t2) or issubclass(t2, t1)):
        return False

    # Reduce nested containers to array objects before comparing them:
    # nested record arrays go through their ``asRecArray()`` method ...
    if isinstance(arr1, tables.nestedrecords.NestedRecArray):
        arr1 = arr1.asRecArray()
    if isinstance(arr2, tables.nestedrecords.NestedRecArray):
        arr2 = arr2.asRecArray()
    # ... and a single nested record becomes the one-row slice of the
    # array it belongs to.
    if isinstance(arr1, tables.nestedrecords.NestedRecord):
        row = arr1.row
        arr1 = arr1.array[row:row+1]
    if isinstance(arr2, tables.nestedrecords.NestedRecord):
        row = arr2.row
        arr2 = arr2.array[row:row+1]

    if isinstance(arr1, numarray.records.RecArray):
        arr1Names = arr1._names
        arr2Names = arr2._names
        if arr1Names != arr2Names:
            return False
        # Fields can be arrays of any supported kind, hence the recursion.
        for fieldName in arr1Names:
            if not areArraysEqual(arr1.field(fieldName),
                                  arr2.field(fieldName)):
                return False
        return True

    if isinstance(arr1, numarray.NumArray):
        if arr1.shape != arr2.shape:
            return False
        if arr1.type() != arr2.type():
            return False
        # Shapes already match, so comparing the flattened arrays
        # compares every element pairwise.
        return numarray.alltrue(arr1.flat == arr2.flat)

    if isinstance(arr1, numarray.strings.CharArray):
        if arr1.shape != arr2.shape:
            return False
        if arr1._type != arr2._type:
            return False
        for i in range(len(arr1)):
            if not arr1[i] == arr2[i]:
                return False
        return True

    # Objects of no supported kind cannot be declared equal.  Answer with
    # an explicit boolean instead of falling off the end with ``None``.
    return False


class PyTablesTestCase(unittest.TestCase):

    """
    Abstract test case with useful methods.

    The helpers below derive names from ``self.id()``, which is split on
    dots: the last component is taken as the test method name and the
    one before it as the test case name.
    """

    def _getName(self):
        """Get the name of this test case."""
        # Second-to-last dotted component of the test identifier.
        return self.id().split('.')[-2]


    def _getMethodName(self):
        """Get the name of the method currently running in the test case."""
        # Last dotted component of the test identifier.
        return self.id().split('.')[-1]


    def _verboseHeader(self):
        """
        Print a nice header for the current test method if verbose.

        The header is an empty line, a separator line made of 30 ``-=``
        pairs and a ``Running <case>.<method>...`` line.  Nothing is
        printed when the module-level ``verbose`` flag is false.
        """

        if verbose:
            name = self._getName()
            methodName = self._getMethodName()

            # Each ``print`` gets one parenthesised argument, so the
            # output is the same under the Python 2 ``print`` statement
            # and the code is also valid where ``print`` is a function.
            # The newline is concatenated to the separator because the
            # statement form writes no space after an item ending in a
            # newline, so the bytes written are unchanged.
            print('\n' + '-=' * 30)
            print("Running %s.%s..." % (name, methodName))



class TempFileMixin:

    """
    Mix-in giving each test its own temporary HDF5 file.

    ``setUp()`` calls ``self._getName()`` to title the file, so this
    class must be combined with one that provides that method.
    """

    def setUp(self):
        """
        Set ``h5file`` and ``h5fname`` instance attributes.

        * ``h5fname``: the name of the temporary HDF5 file.
        * ``h5file``: the writable, empty, temporary HDF5 file.
        """

        # ``mkstemp()`` creates the file itself, so nobody else can claim
        # the name between choosing it and opening it (``mktemp()`` only
        # returns a name and is deprecated for that reason).  Only the
        # name is needed here, so the descriptor is closed right away;
        # opening in ``'w'`` mode then replaces the empty placeholder.
        fd, self.h5fname = tempfile.mkstemp(suffix='.h5')
        os.close(fd)
        try:
            self.h5file = tables.openFile(
                self.h5fname, 'w', title=self._getName())
        except Exception:
            # Do not leave the placeholder behind if the file could not
            # be opened; the original error is re-raised untouched.
            os.remove(self.h5fname)
            raise


    def tearDown(self):
        """Close ``h5file`` and remove ``h5fname``."""

        self.h5file.close()
        self.h5file = None
        os.remove(self.h5fname)


    def _reopen(self, mode='r'):
        """
        Reopen ``h5file`` in the specified ``mode``.

        The current file object is closed first and ``h5file`` is
        replaced by a new one opened on the same ``h5fname``.
        """

        self.h5file.close()
        self.h5file = tables.openFile(self.h5fname, mode)



## Local Variables:
## mode: python
## py-indent-offset: 4
## tab-width: 4
## fill-column: 72
## End:
