#!/usr/bin/python

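# Reference: sample ISC dhcpd.conf stanzas of the three kinds this script
# prints (host, subnet and shared-network).  The generated host stanzas use
# only "hardware ethernet" and "fixed-address".
#
#host fantasia {
#  dhcp-client-identifier
#  hardware ethernet 08:00:07:26:c0:a5;
#  fixed-address fantasia.fugue.com;
#}

#subnet 1.2.3.0 netmask 255.255.255.0 {
#  option routers 1.2.3.4;
#  range 1.2.3.100 1.2.3.200;
#  option domain-name "foo.bar.example.com";
#}

#shared-network "foo" {
#}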

import struct
import sys
from socket import inet_aton, inet_ntoa

from ldaptor.protocols.ldap import ldapclient, distinguishedname, ldapconnector
from ldaptor.protocols import pureber, pureldap
from ldaptor import usage, ldapfilter
from twisted.internet import protocol, reactor, defer


def my_aton_octets(ip):
    """Convert a dotted-quad IPv4 address string to a 32-bit integer.

    The string is parsed with socket.inet_aton(), so any form that function
    accepts is allowed.  The packed 4-byte result is read back as one
    unsigned big-endian (network order) integer, e.g. '1.2.3.4' -> 0x01020304.
    Invalid input makes inet_aton() raise (socket.error / OSError).
    """
    # '!L' = network byte order, unsigned 32 bits: exactly the 4 bytes that
    # inet_aton() returns, without a per-character ord()/shift loop.
    return struct.unpack('!L', inet_aton(ip))[0]

def my_aton_numbits(num):
    """Return the 32-bit netmask integer for a prefix length of *num* bits.

    The *num* most significant bits are set, e.g. 24 -> 0xFFFFFF00.  A value
    of 0 or less gives 0 and a value of 32 or more gives 0xFFFFFFFF.  The
    mask is computed directly rather than bit by bit, so a huge malformed
    prefix value does not make this loop *num* times.
    """
    # Clamp to the meaningful 0..32 range, then shift ones in from the left
    # and drop everything above bit 31.
    bits = max(0, min(num, 32))
    return (0xFFFFFFFF << (32 - bits)) & 0xFFFFFFFF

def my_aton(ip):
    """Turn an address or netmask string into a 32-bit integer.

    A value that int() accepts (e.g. '24') is taken as a prefix length and
    expanded by my_aton_numbits(); anything else is parsed as a dotted-quad
    address by my_aton_octets().
    """
    try:
        prefixLength = int(ip)
    except ValueError:
        # Not a plain number, so treat it as dotted-quad notation.
        return my_aton_octets(ip)
    return my_aton_numbits(prefixLength)

def my_ntoa(n):
    """Format a 32-bit integer as a dotted-quad IPv4 address string.

    Only the low 32 bits of *n* are used, e.g. 0xFFFFFF00 -> '255.255.255.0'.
    The byte order is big-endian, the same order my_aton_octets() reads.
    """
    # Pack as an unsigned big-endian 32-bit value so inet_ntoa() receives the
    # 4-byte network-order buffer it expects.
    return inet_ntoa(struct.pack('!L', n & 0xFFFFFFFF))

class HostIPAddress:
    """One IP address of a Host; printed as one dhcpd 'host' stanza."""

    def __init__(self, host, ipAddress):
        # Back-reference to the owning Host (its dn, name and macAddresses
        # are read by printDHCP).
        self.host = host
        self.ipAddress = ipAddress

    def printDHCP(self, domain, prefix=''):
        """Print a dhcpd.conf 'host' stanza for this address to stdout.

        The stanza is named '<host name>.<domain>', lists every MAC address
        of the host and pins this address with 'fixed-address'.  Every output
        line is prepended with *prefix*.
        """
        lines = ['# %s' % self.host.dn,
                 'host %s.%s {' % (self.host.name, domain)]
        for mac in self.host.macAddresses:
            lines.append('\thardware ethernet %s;' % mac)
        lines.append('\tfixed-address %s;' % self.ipAddress)
        lines.append('}')
        print('\n'.join([prefix + line for line in lines]))

    def __repr__(self):
        # The host is shown by id() only: Host.__repr__ includes its
        # addresses, so repr()-ing the host here would recurse.
        return '%s(host=%s, ipAddress=%r)' % (self.__class__.__name__,
                                              id(self.host),
                                              self.ipAddress)

class Host:
    """A host entry read from LDAP.

    Each address in *ipAddresses* is wrapped in a HostIPAddress that points
    back to this Host.  This lets each address be attached to whichever
    subnet contains it while still printing the host's name and MACs.
    """

    def __init__(self, dn, name, ipAddresses, macAddresses=()):
        self.dn = dn
        self.name = name
        self.ipAddresses = [HostIPAddress(self, ip) for ip in ipAddresses]
        self.macAddresses = macAddresses

    def __repr__(self):
        # Every field is separated by ', '.  The separator after
        # ipAddresses was previously missing, which ran it into macAddresses.
        return ('%s(dn=%r, name=%r, ipAddresses=%r, macAddresses=%r)'
                % (self.__class__.__name__, self.dn, self.name,
                   self.ipAddresses, self.macAddresses))

class Net:
    """An IP subnet read from LDAP, printed as a dhcpd 'subnet' stanza."""

    def __init__(self, dn, name, address, mask,
                 routers=(),
                 dhcpRanges=(),
                 winsServers=(),
                 domainNameServers=(),
                 ):
        self.dn = dn
        self.name = name
        self.address = address    # network address, as given in LDAP
        self.mask = mask          # netmask: dotted quad or prefix length (see my_aton)
        self.routers = routers
        self.dhcpRanges = dhcpRanges
        self.winsServers = winsServers
        self.domainNameServers = domainNameServers
        self.hosts = []           # HostIPAddress objects added via addHost()

    def isInNet(self, ipAddress):
        """Return True if *ipAddress* lies inside this subnet, else False."""
        net = my_aton(self.address)
        mask = my_aton(self.mask)
        ip = my_aton(ipAddress)
        return (ip & mask) == net

    def addHost(self, host):
        """Attach a HostIPAddress that falls inside this subnet."""
        # Internal invariant: the caller matches the address with isInNet()
        # before adding it.
        assert self.isInNet(host.ipAddress)
        self.hosts.append(host)

    def printDHCP(self, domain, prefix=''):
        """Print this subnet's stanza, then one 'host' stanza per host.

        Every line is prepended with *prefix*.  The subnet's domain-name
        option is '<name>.<domain>', and hosts are printed under that domain.
        A host with several addresses in this subnet is printed only once,
        for the first of its addresses that was added.
        """
        # self.mask may be a prefix length; dhcpd wants a dotted quad.
        netmask = my_ntoa(my_aton(self.mask))
        lines = ['# %s' % self.dn,
                 'subnet %s netmask %s {' % (self.address, netmask),
                 '\toption domain-name "%s.%s";' % (self.name, domain)]
        if self.routers:
            lines.append('\toption routers %s;' % ', '.join(self.routers))
        for dhcpRange in self.dhcpRanges:
            lines.append('\trange %s;' % dhcpRange)
        if self.winsServers:
            lines.append('\toption netbios-name-servers %s;'
                         % ', '.join(self.winsServers))
        if self.domainNameServers:
            lines.append('\toption domain-name-servers %s;'
                         % ', '.join(self.domainNameServers))
        lines.append('}')
        print('\n'.join([prefix + line for line in lines]))

        # Keyed by the owning Host object, so each host is printed once.
        seen = {}
        for hostIP in self.hosts:
            if hostIP.host in seen:
                continue
            seen[hostIP.host] = 1
            hostIP.printDHCP(self.name + '.' + domain, prefix=prefix)

    def __repr__(self):
        return ('%s(dn=%r, name=%r, address=%r, mask=%r)'
                % (self.__class__.__name__, self.dn, self.name,
                   self.address, self.mask))

class SharedNet:
    """A named group of Net objects, printed as one dhcpd 'shared-network'."""

    def __init__(self, name):
        self.name = name
        self.nets = []    # member Net objects, in the order they were added

    def addNet(self, net):
        """Add *net* as a member of this shared network."""
        self.nets.append(net)

    def printDHCP(self, domain):
        """Print the shared-network stanza with its subnets indented by a tab."""
        print('shared-network "%s" {' % self.name)
        for member in self.nets:
            member.printDHCP(domain, prefix='\t')
        print('}')
        print('')    # blank line after the stanza

class SearchHosts(ldapclient.LDAPSearch):
    """LDAP search that collects host entries as Host objects.

    A callback is added to *deferred* that replaces the value it fires with
    by that object's ``entries`` list.
    """

    def __init__(self, deferred, client, base, filter):
        # Host objects built by handle_entry().
        self.entries = []

        # Only entries that have both a cn and an ipHostNumber are hosts.
        hostFilter = pureldap.LDAPFilter_and(value=(
            pureldap.LDAPFilter_present('cn'),
            pureldap.LDAPFilter_present('ipHostNumber'),
            ))
        if filter:
            # Narrow further with the caller-supplied filter object.
            hostFilter = pureldap.LDAPFilter_and(value=(filter, hostFilter))

        wantedAttributes = ['cn', 'ipHostNumber', 'macAddress']
        ldapclient.LDAPSearch.__init__(self, deferred, client,
                                       baseObject=base,
                                       filter=hostFilter,
                                       attributes=wantedAttributes)
        deferred.addCallback(lambda search: search.entries)

    def handle_entry(self, objectName, attributes):
        """Build a Host from one entry's (name, values) attribute pairs."""
        args = dict([(str(name), values) for name, values in attributes])

        cnValues = args['cn']
        assert len(cnValues) == 1, \
               "object %s attribute 'cn' has multiple values: %s" \
               % (objectName, cnValues)

        dn = str(objectName)
        hostName = str(cnValues[0])
        ipAddresses = [str(ip) for ip in args['ipHostNumber']]
        macAddresses = [str(mac) for mac in args.get('macAddress', ())]
        self.entries.append(Host(dn, hostName, ipAddresses, macAddresses))

class SearchNets(ldapclient.LDAPSearch):
    """LDAP search that collects network entries as Net objects.

    A callback is added to *deferred* that turns the value it fires with into
    the tuple (entries, sharedNetworks).  Here entries is the list of Nets
    without a sharedNetworkName, and sharedNetworks maps each
    sharedNetworkName to a SharedNet holding the other Nets.
    """

    def __init__(self, deferred, client, base, filter):
        self.sharedNetworks = {}   # sharedNetworkName -> SharedNet
        self.entries = []          # Nets that belong to no shared network

        # Only entries with a cn, network number and netmask are networks.
        netFilter = pureldap.LDAPFilter_and(value=(
            pureldap.LDAPFilter_present('cn'),
            pureldap.LDAPFilter_present('ipNetworkNumber'),
            pureldap.LDAPFilter_present('ipNetmaskNumber'),
            ))
        if filter:
            netFilter = pureldap.LDAPFilter_and(value=(filter, netFilter))
        ldapclient.LDAPSearch.__init__(self, deferred, client,
                                       baseObject=base,
                                       filter=netFilter,
                                       attributes=['cn',
                                                   'ipNetworkNumber',
                                                   'ipNetmaskNumber',
                                                   'router',
                                                   'dhcpRange',
                                                   'winsServer',
                                                   'domainNameServer',
                                                   'sharedNetworkName'])
        deferred.addCallback(lambda x: (x.entries, x.sharedNetworks))

    def _singleValue(self, objectName, args, attribute):
        """Return the only value of *attribute*, asserting there is exactly one."""
        values = args[attribute]
        assert len(values) == 1, \
               "object %s attribute '%s' has multiple values: %s" \
               % (objectName, attribute, values)
        return values[0]

    def handle_entry(self, objectName, attributes):
        """Build a Net from one entry and file it alone or under its shared network."""
        args = {}
        for key, values in attributes:
            args[str(key)] = [str(value) for value in values]

        cn = self._singleValue(objectName, args, 'cn')
        ipNetworkNumber = self._singleValue(objectName, args, 'ipNetworkNumber')
        ipNetmaskNumber = self._singleValue(objectName, args, 'ipNetmaskNumber')
        # str() the DN, as SearchHosts does for Host entries.
        net = Net(str(objectName), cn,
                  ipNetworkNumber, ipNetmaskNumber,
                  routers=args.get('router', ()),
                  dhcpRanges=args.get('dhcpRange', ()),
                  winsServers=args.get('winsServer', ()),
                  domainNameServers=args.get('domainNameServer', ()),
                  )
        if 'sharedNetworkName' in args:
            name = self._singleValue(objectName, args, 'sharedNetworkName')
            if name not in self.sharedNetworks:
                self.sharedNetworks[name] = SharedNet(name)
            self.sharedNetworks[name].addNet(net)
        else:
            self.entries.append(net)

class Search(ldapclient.LDAPClient):
    """LDAP client that binds, fetches nets then hosts, and prints dhcpd.conf."""

    def __init__(self):
        ldapclient.LDAPClient.__init__(self)

    def connectionMade(self):
        """Bind and run the export; the outcome is chained to factory.deferred."""
        d = self.bind()
        d.addCallbacks(callback=self._handle_bind_success)
        d.addErrback(defer.logError)
        d.chainDeferred(self.factory.deferred)

    def _handle_bind_success(self, x):
        """After binding, search for networks (continued in haveNets)."""
        d1 = defer.Deferred()
        SearchNets(d1, self, self.factory.base, self.factory.filt)
        d1.addCallbacks(callback=self.haveNets,
                        errback=defer.logError)
        return d1

    def haveNets(self, data):
        """Store the (nets, sharedNets) result, then search for hosts."""
        nets, sharedNets = data
        self.nets = nets
        self.sharedNets = sharedNets
        d = defer.Deferred()
        SearchHosts(d, self, self.factory.base, self.factory.filt)
        d.addCallbacks(callback=self.haveHosts,
                       errback=defer.logError)
        return d

    def haveHosts(self, hosts):
        """Attach every host address to its subnet, then print dhcpd.conf.

        Each address goes to the first net that contains it, checking
        standalone nets before members of shared networks.  Addresses in no
        net are reported on stderr and skipped.  Shared networks are printed
        before standalone nets.
        """
        # The candidate nets are the same for every address, so build the
        # list once rather than re-concatenating it per address.
        candidates = list(self.nets)
        for sharedNet in self.sharedNets.values():
            candidates.extend(sharedNet.nets)

        for host in hosts:
            for hostIP in host.ipAddresses:
                for net in candidates:
                    if net.isInNet(hostIP.ipAddress):
                        net.addHost(hostIP)
                        break
                else:
                    # Report the address itself (not the object's repr) and
                    # the owning host's DN, so the bad entry can be found.
                    sys.stderr.write(
                        "IP address %s of %s is in no net, discarding.\n"
                        % (hostIP.ipAddress, host.dn))

        for sharedNet in self.sharedNets.values():
            sharedNet.printDHCP(self.factory.dnsDomain)
        for net in self.nets:
            net.printDHCP(self.factory.dnsDomain)

class SearchFactory(protocol.ClientFactory):
    """Client factory that holds the settings read by the Search protocol."""
    protocol = Search

    def __init__(self, deferred, base, filt, dnsDomain):
        # Search settings: base DN, optional parsed filter (or None), and
        # the DNS domain appended to subnet and host names.
        self.base = base
        self.filt = filt
        self.dnsDomain = dnsDomain
        # Completion deferred: Search.connectionMade chains its result here,
        # and clientConnectionFailed errbacks it.
        self.deferred = deferred

    def clientConnectionFailed(self, connector, reason):
        """Propagate a failed connection attempt through the deferred."""
        self.deferred.errback(reason)

# Process exit code passed to sys.exit() by main(); error() sets it to 1.
exitStatus = 0

def error(fail):
    """Errback: print the failure message on stderr and mark the run failed."""
    global exitStatus
    sys.stderr.write('fail: %s\n' % (fail.getErrorMessage(),))
    exitStatus = 1

def main(base, serviceLocationOverride, filter_text, dnsDomain):
    """Export the nets and hosts under *base* as dhcpd.conf text on stdout.

    *serviceLocationOverride* is passed to LDAPConnector as its overrides.
    *filter_text*, if not None, is parsed as an LDAP filter that narrows both
    searches.  The reactor runs until the export finishes, and then the
    process exits with exitStatus.
    """
    from twisted.python import log
    log.startLogging(sys.stderr, setStdout=0)

    parsedFilter = None
    if filter_text is not None:
        parsedFilter = ldapfilter.parseFilter(filter_text)

    finished = defer.Deferred()
    factory = SearchFactory(finished, base, parsedFilter, dnsDomain)
    finished.addErrback(error)

    def stopReactor(result):
        # Stop whether the export succeeded or failed.
        return reactor.stop()
    finished.addBoth(stopReactor)

    baseDN = distinguishedname.DistinguishedName(stringValue=base)
    connector = ldapconnector.LDAPConnector(reactor, baseDN, factory,
                                            overrides=serviceLocationOverride)
    connector.connect()
    reactor.run()
    sys.exit(exitStatus)

class MyOptions(usage.Options, usage.Options_service_location, usage.Options_base):
    """LDAPtor dhcpd config file exporter"""
    # Command-line options.  The __main__ block also reads 'base' and
    # 'service-location' from opts; presumably the two mixins provide
    # them (NOTE(review): confirm against ldaptor.usage).

    # Twisted usage format: (long name, short name, default, description).
    # The DNS domain is appended to subnet and host names in the output.
    optParameters = (
	('dns-domain', None, 'example.com',
	 "DNS domain name"),
	)

    def parseArgs(self, filter=None):
	# One optional positional argument: an LDAP filter string, which main()
	# parses with ldapfilter.parseFilter().  None when it is not given.
	self.opts['filter'] = filter

if __name__ == "__main__":
    import sys

    try:
        config = MyOptions()
        config.parseOptions()
    except usage.UsageError:
        # Bad command line: report it prefixed with the program name, then stop.
        problem = sys.exc_info()[1]
        sys.stderr.write('%s: %s\n' % (sys.argv[0], problem))
        sys.exit(1)

    opts = config.opts
    main(opts['base'],
         opts['service-location'],
         opts['filter'],
         opts['dns-domain'])
