############################################################################
##
## Copyright (c) 2000, 2001, 2002 BalaBit IT Ltd, Budapest, Hungary
##
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 2 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; if not, write to the Free Software
## Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
##
##
## $Id: Dispatch.py,v 1.8.2.11 2003/10/01 12:44:24 bazsi Exp $
##
## Author  : Bazsi
## Auditor : 
## Last audited version: 1.15
## Notes:
##
############################################################################

from Zorp import *
from Session import MasterSession
from Cache import ShiftCache
from traceback import print_exc
from string import atoi
import Zorp

# Optional listen/unlisten hooks; nothing in this module reads them.
# NOTE(review): presumably installed by other parts of Zorp — confirm.
listen_hook = unlisten_hook = None

# Threshold handed to ShiftCache for the per-dispatcher service caches
# of ZoneDispatcher and CSZoneDispatcher.
zone_dispatcher_shift_threshold = 1000

# Module-level mapping; not otherwise used by the code in this module.
client_ips = {}

# Dispatch priorities; ZD_PRI_LISTEN is the default 'prio' used by
# AbstractDispatch.
ZD_PRI_LISTEN, ZD_PRI_NORMAL, ZD_PRI_RELATED = 100, 0, -100


class AbstractDispatch:
	"""Common base class of the dispatchers in this module.

	Wraps a low-level Dispatch object and registers the instance in
	Globals.dispatches, from where purgeDispatches() destroys it.

	Attributes

	  session_id -- session id handed to the low-level Dispatch object

	  dispatch -- the low-level Dispatch object
	"""

	def __init__(self, session_id, protocol, bindto, kw=None):
		"""Constructor to initialize an AbstractDispatch instance.

		Creates the low-level Dispatch object with self.accepted as its
		callback, and registers this instance in Globals.dispatches.

		Arguments

		  self -- this instance

		  session_id -- session id passed to the low-level Dispatch object

		  protocol -- protocol passed to the low-level Dispatch object

		  bindto -- address to bind to

		  kw -- optional keyword arguments; 'prio' selects the dispatch
		        priority (ZD_PRI_LISTEN by default), and the mapping
		        itself is passed on unchanged (possibly None) to the
		        Dispatch object
		"""
		self.session_id = session_id
		prio = getKeywordArg(kw, 'prio', ZD_PRI_LISTEN)
		self.dispatch = Dispatch(self.session_id, protocol, bindto, prio, self.accepted, kw)
		Globals.dispatches.append(self)

	def accepted(self, *args):
		"""Callback invoked by the low-level Dispatch object.

		This default implementation rejects everything; descendant
		classes override it (Dispatcher.accepted takes stream,
		client_address, client_local and client_listen).  Any
		arguments are accepted and ignored, so the default can be
		called with the same arguments as the overrides.

		Arguments

		  self -- this instance

		  args -- callback arguments, ignored here

		Returns

		  Z_REJECT
		"""
		return Z_REJECT

	def destroy(self):
		"""Stops the listener on the given port.

		Calls the destroy method of the low-level Dispatch object.

		Arguments

		  self -- this instance
		"""
		self.dispatch.destroy()


class Dispatcher(AbstractDispatch):
	"""Class encapsulating a Listener, which starts services for established connections.
	
	This class is the starting point of Zorp services. It listens on the
	given port, and when a connection is accepted it starts a session
	and the given service.
                   
	Attributes
	
	  listen --       A Zorp.Listen instance
	  
	  service --      the service to be started
	  
	  bindto --       bind address
	  
	  local --        local address where the listener is bound

	  protocol --     the protocol we were bound to
	"""

	def __init__(self, protocol, bindto, service, kw=None):
		"""Constructor to initialize a Listen instance
		
		Creates the instance, sets the initial attributes, and
		starts the listener

		Arguments

		  self -- this instance
		  
		  bindto --  the address to bind to
		  
		  service -- the service name to start
		  
		  transparent -- TRUE if this is a listener of a 
		                 transparent service, specifying this is
		                 not mandatory but performs additional checks

		"""
		try:
			if service != None:
				self.service = Globals.services[service]
			else:
				self.service = None
		except KeyError:
			raise ServiceException, "Service %s not found" % (service,)
		self.bindto = bindto
		self.transparent = getKeywordArg(kw, 'transparent', FALSE)
		self.protocol = protocol
	        AbstractDispatch.__init__(self, Zorp.firewall_name, protocol, bindto, kw=kw)
	        
	def accepted(self, stream, client_address, client_local, client_listen):
		"""Callback to inform the python layer about incoming connections.
		
		This callback is called by the core when a connection is 
		accepted. Its primary function is to check access control
		(whether the client is permitted to connect to this port),
		and to spawn a new session to handle the connection.
		
		Exceptions raised due to policy violations are handled here.

		Arguments

		  self   --  this instance
		  
		  stream --  the stream of the connection to the client

		  client_address --  the address of the client

		  client_local -- client local address (contains the original destination if transparent)

		  client_listen -- the address where the listener was bound to
		  
		Returns
		
		  TRUE if the connection is accepted
		"""
		global client_ips
		
		if stream == None:
			return None
		session = None
		try:
			session = MasterSession(self.session_id)
			session.protocol = self.protocol
			stream.name = session.session_id
			session.setClient(stream, client_address)
			session.client_local = client_local
			session.client_listen = client_listen
			service = self.getService(session)
			if not service:
				raise DACException, "No applicable service found"
			session.setService(service)
			
			service.router.routeConnection(session)

			if self.transparent and client_local.port == client_listen.port:
 				log(session.session_id, CORE_ERROR, 1, "Transparent listener connected directly, dropping connection; local='%s', client_local='%s'" % (session.client_listen, session.client_local))
			elif session.isClientPermitted() == Z_ACCEPT:
 				log(session.session_id, CORE_DEBUG, 8, "Connection accepted; client_address='%s'" % (client_address))
				return session.service.startInstance(session)
			raise DACException, "This service was not permitted outbound"
		except ZoneException, s:
 			log(session.session_id, CORE_POLICY, 1, "Zone not found; zone='%s'" % (s,))
		except DACException, s:
 			log(session.session_id, CORE_POLICY, 1, "DAC policy violation; info='%s'" % (s,))
		except MACException:
 			log(session.session_id, CORE_POLICY, 1, "MAC policy violation;")
		except AuthException:
 			log(session.session_id, CORE_POLICY, 1, "Authentication failure;")
		except LimitException:
 			log(session.session_id, CORE_POLICY, 1, "Connection over permitted limits;")
		except LicenseException:
			log(session.session_id, CORE_POLICY, 1, "Attempt to use an unlicensed component, or number of licensed hosts exceeded;")
		except:
			print_exc()
			
		if session != None: 
			session.destroy()

		return None

	def getService(self, session):
		"""Returns the service associated with the listener

		Returns the service to start.

		Arguments

		  self    -- this instance
		  
		  session -- session reference
		"""
		return self.service


class ZoneDispatcher(Dispatcher):
	"""Class to listen on the selected address, and start a service based on the client's zone.

	This class is similar to a simple Dispatcher, but instead of
	starting a fixed service, it chooses one based on the client
	zone.

	It takes a mapping of service names indexed by a zone name, with
	an exception of the '*' zone, which matches anything.

	Attributes

	  services -- service names indexed by zone name

	  follow_parent -- whether administrative parent zones of the client
	                   zone are also matched

	  cache -- ShiftCache of resolved services indexed by client zone name
	"""

	def __init__(self, protocol, bindto, services, kw=None):
		"""Constructor to initialize a ZoneDispatcher instance.

		This constructor initializes a ZoneDispatcher instance and sets
		its initial attributes based on arguments.

		Arguments

		  self -- this instance

		  protocol -- the protocol to bind to

		  bindto -- bind to this address

		  services -- a mapping between zone names and service names

		  kw -- optional keyword arguments; recognized here:

		        follow_parent -- whether to follow the administrative
		                         hierarchy when finding the correct service
		"""
		self.follow_parent = getKeywordArg(kw, 'follow_parent', FALSE)
		# Everything getService() relies on is set up before the base
		# constructor hands self.accepted to the core.
		self.services = services
		self.cache = ShiftCache('sdispatch(%s)' % (bindto,), zone_dispatcher_shift_threshold)
		Dispatcher.__init__(self, protocol, bindto, None, kw=kw)

	def _zoneHierarchy(self, zone):
		"""Rank zone names by their distance from the given zone.

		Arguments

		  self -- this instance

		  zone -- the zone to start from

		Returns

		  A (hierarchy, max_level) tuple.  hierarchy maps a zone name to
		  its level: 0 for the zone itself, then (when follow_parent is
		  set) one more for each administrative parent, and the '*'
		  wildcard after all of them.  max_level is the level used for
		  names that do not match at all.
		"""
		hierarchy = {}
		if not self.follow_parent:
			hierarchy[zone.getName()] = 0
			hierarchy['*'] = 1
			return hierarchy, 10
		level = 0
		while zone:
			hierarchy[zone.getName()] = level
			zone = zone.admin_parent
			level = level + 1
		hierarchy['*'] = level
		return hierarchy, level + 1

	def getService(self, session):
		"""Virtual function which returns the service to be run.

		This function is called by our base class to find out the
		service to be used for the current session. It uses the
		client zone name to decide which service to use.  Results,
		including "no service" (None), are cached per client zone.

		Arguments

		  self -- this instance

		  session -- session we are starting

		Returns

		  the service object, or None if no service matches
		"""
		cache_ndx = session.client_zone.getName()

		try:
			cached = self.cache.lookup(cache_ndx)
		except KeyError:
			# not cached yet, resolve below
			pass
		else:
			if not cached:
				log(None, CORE_POLICY, 2, "No applicable service found for this client zone (cached); bindto='%s', client_zone='%s'" % (self.bindto, session.client_zone))
			return cached

		src_hierarchy, max_level = self._zoneHierarchy(session.client_zone)

		# Choose the service whose zone is closest to the client zone; on
		# a tie the first one enumerated wins.  Unknown zones rank at
		# max_level and so are never chosen.  Starting from max_level
		# also keeps an empty services mapping well-defined.
		best = None
		best_src_level = max_level
		for spec in self.services.keys():
			src_level = src_hierarchy.get(spec, max_level)
			if src_level < best_src_level:
				best = self.services[spec]
				best_src_level = src_level

		if best_src_level < max_level:
			s = Globals.services[best]
		else:
			log(None, CORE_POLICY, 2, "No applicable service found for this client zone; bindto='%s', client_zone='%s'" % (self.bindto, session.client_zone))
			s = None

		self.cache.store(cache_ndx, s)
		return s

class CSZoneDispatcher(Dispatcher):
	"""Class to listen on the selected address, and start a service based on the client's and the original server zone.

	This class is similar to a simple Dispatcher, but instead of
	starting a fixed service, it chooses one based on the client
	and the destined server zone.

	It takes a mapping of service names indexed by a (client zone,
	server zone) name pair, with an exception of the '*' zone, which
	matches anything.

	NOTE: the server zone might change during proxy and NAT processing,
	therefore the server zone used here only matches the real
	destination if those phases leave the server address intact.

	Attributes

	  services -- service names indexed by (client zone, server zone) names

	  follow_parent -- whether administrative parent zones are also matched

	  cache -- ShiftCache of resolved services indexed by the zone name pair
	"""

	def __init__(self, protocol, bindto, services, kw=None):
		"""Constructor to initialize a CSZoneDispatcher instance.

		This constructor initializes a CSZoneDispatcher instance and
		sets its initial attributes based on arguments.

		Arguments

		  self -- this instance

		  protocol -- the protocol to bind to

		  bindto -- bind to this address

		  services -- a mapping between (client zone, server zone) name
		              pairs and service names

		  kw -- optional keyword arguments; recognized here:

		        follow_parent -- whether to follow the administrative
		                         hierarchy when finding the correct service
		"""
		self.follow_parent = getKeywordArg(kw, 'follow_parent', FALSE)
		# Everything getService() relies on is set up before the base
		# constructor hands self.accepted to the core.
		self.services = services
		self.cache = ShiftCache('csdispatch(%s)' % (bindto,), zone_dispatcher_shift_threshold)
		Dispatcher.__init__(self, protocol, bindto, None, kw=kw)

	def _zoneHierarchy(self, zone):
		"""Rank zone names by their distance from the given zone.

		Arguments

		  self -- this instance

		  zone -- the zone to start from

		Returns

		  A (hierarchy, max_level) tuple.  hierarchy maps a zone name to
		  its level: 0 for the zone itself, then (when follow_parent is
		  set) one more for each administrative parent, and the '*'
		  wildcard after all of them.  max_level is the level used for
		  names that do not match at all.
		"""
		hierarchy = {}
		if not self.follow_parent:
			hierarchy[zone.getName()] = 0
			hierarchy['*'] = 1
			return hierarchy, 10
		level = 0
		while zone:
			hierarchy[zone.getName()] = level
			zone = zone.admin_parent
			level = level + 1
		hierarchy['*'] = level
		return hierarchy, level + 1

	def getService(self, session):
		"""Virtual function which returns the service to be run.

		This function is called by our base class to find out the
		service to be used for the current session. It uses the
		client and the server zone name to decide which service to
		use.  Results, including "no service" (None), are cached per
		zone name pair.

		Arguments

		  self -- this instance

		  session -- session we are starting

		Returns

		  the service object, or None if no service matches
		"""
		from Zone import root_zone
		dest_zone = root_zone.findZone(session.client_local)

		cache_ndx = (session.client_zone.getName(), dest_zone.getName())

		try:
			cached = self.cache.lookup(cache_ndx)
		except KeyError:
			# not cached yet, resolve below
			pass
		else:
			if not cached:
				log(None, CORE_POLICY, 2, "No applicable service found for this client & server zone (cached); bindto='%s', client_zone='%s', server_zone='%s'" % (self.bindto, session.client_zone, dest_zone))
			return cached

		src_hierarchy, src_max = self._zoneHierarchy(session.client_zone)
		dst_hierarchy, dst_max = self._zoneHierarchy(dest_zone)
		max_level = max(src_max, dst_max)

		# Choose the closest client zone, then the closest server zone;
		# on a tie the first one enumerated wins.  A pair with an unknown
		# zone ranks at max_level on both axes and is never chosen.
		# Starting from max_level keeps an empty mapping well-defined.
		best = None
		best_src_level = best_dst_level = max_level
		for spec in self.services.keys():
			try:
				src_level = src_hierarchy[spec[0]]
				dst_level = dst_hierarchy[spec[1]]
			except KeyError:
				src_level = dst_level = max_level
			if (src_level, dst_level) < (best_src_level, best_dst_level):
				best = self.services[spec]
				best_src_level = src_level
				best_dst_level = dst_level

		if best_src_level < max_level and best_dst_level < max_level:
			s = Globals.services[best]
		else:
			log(None, CORE_POLICY, 2, "No applicable service found for this client & server zone; bindto='%s', client_zone='%s', server_zone='%s'" % (self.bindto, session.client_zone, dest_zone))
			s = None
		self.cache.store(cache_ndx, s)
		return s

def purgeDispatches():
	"""Destroy every registered dispatcher and drop the registry.

	Calls destroy() on each object in Globals.dispatches, in list
	order, then deletes the Globals.dispatches attribute itself.
	"""
	registered = Globals.dispatches
	for dispatcher in registered:
		dispatcher.destroy()
	del Globals.dispatches
