# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.

# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
"""Tests for the specific behaviour of astng scoped nodes (i.e. module, class
and function nodes).

Copyright (c) 2003-2005 LOGILAB S.A. (Paris, FRANCE).
http://www.logilab.fr/ -- mailto:contact@logilab.fr
"""

__revision__ = "$Id: unittest_scoped_nodes.py,v 1.4 2005/11/02 11:56:49 syt Exp $"

import unittest
import sys

from logilab.astng import builder, nodes, scoped_nodes, \
     ResolveError, NotFoundError, ASTNGManager

# Shared builder and pre-built fixture trees used by every test case below.
manager = ASTNGManager()
abuilder = builder.ASTNGBuilder(manager) 
# The 'data/...' paths are relative; presumably they resolve against the
# current working directory, so the suite is expected to run from the
# directory containing 'data/' -- TODO confirm.
MODULE = abuilder.file_build('data/module.py', 'data.module')
MODULE2 = abuilder.file_build('data/module2.py', 'data.module2')
NONREGR = abuilder.file_build('data/nonregr.py', 'data.nonregr')

def _test_dict_interface(self, node, test_attr):
    self.assert_(node[test_attr] is node[test_attr])
    self.assert_(test_attr in node)
    node.keys()
    node.values()
    node.items()
    iter(node)


class ModuleNodeTC(unittest.TestCase):
    """Tests for Module nodes built from the data/ fixture files."""

    def test_dict_interface(self):
        _test_dict_interface(self, MODULE, 'YO')

    def test_resolve(self):
        yo = MODULE.resolve('YO')
        self.assert_(isinstance(yo, nodes.Class))
        self.assertEquals(yo.name, 'YO')
        # 'redirect' and 'spawn' are aliases: they must resolve to the
        # definitions they refer to ('nested_args' and 'Execute')
        red = MODULE.resolve('redirect')
        self.assert_(isinstance(red, nodes.Function))
        self.assertEquals(red.name, 'nested_args')
        spawn = MODULE.resolve('spawn')
        self.assert_(isinstance(spawn, nodes.Class))
        self.assertEquals(spawn.name, 'Execute')
        # built-in objects
        none = MODULE.resolve('None')
        self.assertEquals(none.value, None)
        obj = MODULE.resolve('object')
        self.assert_(isinstance(obj, nodes.Class))
        self.assertEquals(obj.name, 'object')
        self.assertEquals(NONREGR.resolve('enumerate').name, 'enumerate')
        # resolve through a package redirection; 'data' is put on the path,
        # presumably so that 'appl' is importable as a top-level package
        saved_path = sys.path[:]
        sys.path.insert(1, 'data')
        try:
            m = abuilder.file_build('data/appl/myConnection.py', 'appl.myConnection')
            cnx = m.resolve_dotted('SSL1.Connection')
            self.assertEquals(cnx.__class__, nodes.Class)
            self.assertEquals(cnx.name, 'Connection')
            self.assertEquals(cnx.root().name, 'Connection1')
        finally:
            # restore the whole path instead of deleting index 1, which may
            # no longer hold our entry if sys.path was changed meanwhile
            sys.path[:] = saved_path
        # unknown names raise ResolveError
        self.assertRaises(ResolveError, MODULE.resolve, 'YOAA')

    def test_resolve_dotted(self):
        exists = MODULE.resolve_dotted('os.path.exists')
        self.assert_(isinstance(exists, nodes.Function))
        self.assertEquals(exists.name, 'exists')
        self.assertRaises(ResolveError, NONREGR['toto'].resolve_dotted, 'v.get')
        # resolve sub package: only checks that no ResolveError is raised
        MODULE2.resolve_dotted('data.module')

    def test_resolve_all(self):
        # only the resolvable name ('os.path.exists') is expected back
        exists, = MODULE.resolve_all(['os.path.exists', 'v.get'])
        self.assert_(isinstance(exists, nodes.Function))
        self.assertEquals(exists.name, 'exists')

    def test_wildard_import_names(self):
        # with __all__ the declared order is kept; without it, the result is
        # compared after sorting
        m = abuilder.file_build('data/all.py', 'all')
        self.assertEquals(m.wildcard_import_names(), ['Aaa', '_bla', 'name'])
        m = abuilder.file_build('data/notall.py', 'notall')
        res = m.wildcard_import_names()
        res.sort()
        self.assertEquals(res, ['Aaa', 'func', 'name', 'other'])

    def test_as_string(self):
        """just check as_string on a whole module doesn't raise an exception
        """
        self.assert_(MODULE.as_string())
        self.assert_(MODULE2.as_string())
        
        
class FunctionNodeTC(unittest.TestCase):
    """Tests for Function nodes (module-level functions and methods)."""

    def test_dict_interface(self):
        _test_dict_interface(self, MODULE['global_access'], 'local')

    def test_resolve(self):
        # lookups from a method reach module globals and built-ins
        meth = MODULE['YOUPI']['method']
        self.assert_(isinstance(meth.resolve('MY_DICT'), nodes.Dict))
        self.assertEquals(meth.resolve('None').value, None)
        self.assertRaises(ResolveError, meth.resolve, 'YOAA')

    def test_resolve_argument_with_default(self):
        # an argument resolves to the node given as its default value
        resolved = MODULE2['make_class'].resolve('base')
        self.assert_(isinstance(resolved, nodes.Class), resolved.__class__)
        self.assertEquals(resolved.name, 'YO')
        self.assertEquals(resolved.root().name, 'data.module')

    def test_default_value(self):
        make_class = MODULE2['make_class']
        for argname, node_class in (('base', nodes.Getattr),
                                    ('args', nodes.Tuple),
                                    ('kwargs', nodes.Dict)):
            self.assert_(isinstance(make_class.default_value(argname), node_class))
        # an argument without default raises NoDefault
        self.assertRaises(scoped_nodes.NoDefault, make_class.default_value, 'any')

    def test_navigation(self):
        func = MODULE['global_access']
        self.assertEquals(func.statement(), func)
        before = func.previous_sibling()
        self.assert_(isinstance(before, nodes.Assign))
        self.assert_(before is func.getChildNodes()[0].previous_sibling())
        after = func.next_sibling()
        self.assert_(isinstance(after, nodes.Class))
        self.assertEquals(after.name, 'YO')
        self.assert_(after is func.getChildNodes()[0].next_sibling())
        # walk forward to the module's last statement...
        last = after
        for _ in range(3):
            last = last.next_sibling()
        self.assert_(isinstance(last, nodes.Assign))
        self.assertEquals(last.next_sibling(), None)
        # ...and backward to its first one
        first = before
        for _ in range(4):
            first = first.previous_sibling()
        self.assertEquals(first.previous_sibling(), None)

    def test_nested_args(self):
        nested = MODULE['nested_args']
        self.assertEquals(nested.argnames, ['a', ('b', 'c', 'd')])
        # tuple-unpacked arguments are all registered as locals
        names = nested.keys()
        names.sort()
        self.assertEquals(names, ['a', 'b', 'c', 'd'])
        self.assertEquals(nested.type, 'function')

    def test_format_args(self):
        for scope, funcname, expected in (
                (MODULE2, 'make_class', 'any, base=data.module.YO, *args, **kwargs'),
                (MODULE, 'nested_args', 'a, (b,c,d)')):
            self.assertEquals(scope[funcname].format_args(), expected)

    def test_is_abstract(self):
        abstract_cls = MODULE2['AbstractClass']
        self.assert_(abstract_cls['to_override'].is_abstract(pass_is_abstract=False))
        self.assert_(not abstract_cls['return_something'].is_abstract(pass_is_abstract=False))
        # non regression: a `raise "string"` statement must not make
        # is_abstract() itself raise
        self.assert_(not MODULE2['raise_string'].is_abstract(pass_is_abstract=False))
        
##     def test_raises(self):
##         method = MODULE2['AbstractClass']['to_override']
##         self.assertEquals([str(term) for term in method.raises()],
##                           ["CallFunc(Name('NotImplementedError'), [], None, None)"] )
        
##     def test_returns(self):
##         method = MODULE2['AbstractClass']['return_something']
##         # use string comp since Node doesn't handle __cmp__ 
##         self.assertEquals([str(term) for term in method.returns()],
##                           ["Const('toto')", "Const(None)"])

        
class ClassNodeTC(unittest.TestCase):
    """Tests for Class nodes: lookups, navigation, ancestry and methods."""

    def test_dict_interface(self):
        _test_dict_interface(self, MODULE['YOUPI'], 'method')

    def test_resolve(self):
        youpi = MODULE['YOUPI']
        self.assert_(isinstance(youpi.resolve('MY_DICT'), nodes.Dict))
        self.assertEquals(youpi.resolve('None').value, None)
        builtin_object = youpi.resolve('object')
        self.assert_(isinstance(builtin_object, nodes.Class))
        self.assertEquals(builtin_object.name, 'object')
        self.assertRaises(ResolveError, youpi.resolve, 'YOAA')

    def test_navigation(self):
        yo = MODULE['YO']
        self.assertEquals(yo.statement(), yo)
        prev_node = yo.previous_sibling()
        self.assert_(isinstance(prev_node, nodes.Function), prev_node)
        self.assertEquals(prev_node.name, 'global_access')
        next_node = yo.next_sibling()
        self.assert_(isinstance(next_node, nodes.Class))
        self.assertEquals(next_node.name, 'YOUPI')

    def test_local_attr_ancestors(self):
        youpi = MODULE['YOUPI']
        ancestors = youpi.local_attr_ancestors('__init__')
        first = ancestors.next()
        self.assert_(isinstance(first, nodes.Class))
        self.assertEquals(first.name, 'YO')
        self.assertRaises(StopIteration, ancestors.next)
        # no ancestor defines 'method'
        self.assertRaises(StopIteration, youpi.local_attr_ancestors('method').next)

    def test_instance_attr_ancestors(self):
        youpi = MODULE['YOUPI']
        ancestors = youpi.instance_attr_ancestors('yo')
        first = ancestors.next()
        self.assert_(isinstance(first, nodes.Class))
        self.assertEquals(first.name, 'YO')
        self.assertRaises(StopIteration, ancestors.next)
        # no ancestor sets the 'member' instance attribute
        self.assertRaises(StopIteration, youpi.instance_attr_ancestors('member').next)

    def test_methods(self):
        def sorted_names(method_nodes):
            names = [meth.name for meth in method_nodes]
            names.sort()
            return names
        expected = ['__init__', 'class_method', 'method', 'static_method']
        youpi = MODULE['YOUPI']
        self.assertEquals(sorted_names(youpi.methods()), expected)
        self.assertEquals(sorted_names(youpi.mymethods()), expected)
        # Specialization defines no method itself but inherits YOUPI's ones
        spec = MODULE2['Specialization']
        self.assertEquals(sorted_names(spec.mymethods()), [])
        self.assertEquals(spec.local_attr('method').name, 'method')
        self.assertRaises(NotFoundError, spec.local_attr, 'nonexistant')
        self.assertEquals(sorted_names(spec.methods()), expected)

    def test_rhs(self):
        self.assert_(isinstance(MODULE['MY_DICT'].rhs(), nodes.Dict))
        assigned = MODULE['YO']['a'].rhs()
        self.assert_(isinstance(assigned, nodes.Const))
        self.assertEquals(assigned.value, 1)

    def test_ancestors(self):
        for scope, classname, expected in (
                (MODULE, 'YOUPI', ['YO']),
                (MODULE2, 'Specialization', ['YOUPI', 'YO', 'YO'])):
            self.assertEquals([anc.name for anc in scope[classname].ancestors()],
                              expected)

    def test_type(self):
        for scope, classname, expected in (
                (MODULE, 'YOUPI', 'class'),
                (MODULE2, 'Metaclass', 'metaclass'),
                (MODULE2, 'MyException', 'exception'),
                (MODULE2, 'MyIFace', 'interface'),
                (MODULE2, 'MyError', 'exception')):
            self.assertEquals(scope[classname].type, expected)

    def test_interfaces(self):
        expectations = (('Concrete0', ['MyIFace']),
                        ('Concrete1', ['MyIFace', 'AnotherIFace']),
                        ('Concrete2', ['MyIFace', 'AnotherIFace']),
                        ('Concrete23', ['MyIFace', 'AnotherIFace']))
        for classname, expected in expectations:
            implemented = [iface.name for iface in MODULE2[classname].interfaces()]
            self.assertEquals(implemented, expected)

    def test_inner_classes(self):
        outer = NONREGR['Ccc']
        inner = NONREGR['Ccc']['Eee']
        self.assertEquals(outer.resolve('Ddd').name, 'Ddd')
        self.assertEquals([anc.name for anc in inner.ancestors()],
                          ['Ddd', 'Aaa', 'object'])
    
# Test case classes exported by this module. Only names defined here may be
# listed: `from <this module> import *` raises AttributeError on a missing one.
__all__ = ('ModuleNodeTC', 'FunctionNodeTC', 'ClassNodeTC')

if __name__ == '__main__':
    unittest.main()
