#!/usr/bin/python

import sys
from twisted.internet.defer import succeed, fail
from twisted.internet import defer, reactor
from twisted.internet.stdio import StandardIO
from twisted.protocols.basic import LineReceiver
from twisted.internet.protocol import ClientFactory
from twisted.python.failure import Failure
from ldaptor.protocols.ldap import ldapclient, distinguishedname, ldapconnector
from ldaptor.protocols.ldap.ldapsyntax import LDAPEntry
from ldaptor import usage
from cStringIO import StringIO

class ExitSentinel:
    """Queue marker meaning "no more work; stop the reactor".

    PdnsPipeProtocol.connectionLost appends the one-element list
    [ExitSentinel] (the class object itself, never an instance) to the
    work queue; when _doWork reaches that entry it stops the reactor.
    """

class Search(ldapclient.LDAPClient):
    """LDAP connection that serves the PowerDNS pipe once it has bound.

    After connecting it calls self.bind() with no arguments (presumably an
    anonymous bind -- confirm against the ldaptor version in use) and, on
    success, attaches a PdnsPipeProtocol to standard input/output that
    answers questions by searching under self.factory.base.
    """

    # The stdio protocol created by bind_ok; None until the bind succeeds.
    pdnsPipeProtocol = None

    def connectionMade(self):
        # Let LDAPClient do its own connection bookkeeping before binding.
        ldapclient.LDAPClient.connectionMade(self)
        d = self.bind()
        d.addCallback(self.bind_ok)
        d.addErrback(self._bindFailed)

    def _bindFailed(self, reason):
        """Errback for bind/bind_ok: log the failure, then drop the link.

        Dropping the connection leads to connectionLost, which stops the
        reactor when no pipe protocol was ever started.
        """
        from twisted.python import log
        log.err(reason)
        self.transport.loseConnection()

    def bind_ok(self, dummy):
        """Start answering PowerDNS questions on stdin/stdout."""
        entry = LDAPEntry(client=self, dn=self.factory.base)
        self.pdnsPipeProtocol = PdnsPipeProtocol(entry, self.factory.dnsDomain)
        # No reactor.addReader() here: StandardIO starts reading stdin as
        # part of its construction, and on newer Twisted releases it is not
        # itself a readable file descriptor.
        StandardIO(self.pdnsPipeProtocol)

    def connectionLost(self, reason):
        """Let LDAPClient clean up, then shut the pipe (or the process) down.

        LDAPClient.connectionLost runs first; it presumably errbacks LDAP
        operations still in flight (verify for the ldaptor version in use),
        which gives queued DNS questions their FAIL answers so the pipe's
        exit marker is not stuck behind them.
        """
        global exitStatus
        ldapclient.LDAPClient.connectionLost(self, reason)
        if self.pdnsPipeProtocol:
            self.pdnsPipeProtocol.connectionLost()
        else:
            # The bind never succeeded, so no pipe protocol exists to stop
            # the reactor; stop it here and report failure via the exit code.
            exitStatus = 1
            reactor.stop()

class SearchFactory(ClientFactory):
    """Builds Search connections and carries their configuration.

    @param base: LDAP search base string; Search.bind_ok uses it as the DN
        of the entry that queries are run under.
    @param dnsDomain: DNS domain served; passed on to PdnsPipeProtocol.
    """
    protocol = Search
    # Not read anywhere in this file; kept in case outside code refers to it.
    ldapClient = None

    def __init__(self, base, dnsDomain):
        self.base = base
        self.dnsDomain = dnsDomain

    def clientConnectionFailed(self, connector, reason):
        """Give up when the LDAP server cannot be reached.

        Without this the reactor would keep running with nothing to do and
        the backend would hang instead of exiting.
        """
        global exitStatus
        from twisted.python import log
        log.err(reason)
        exitStatus = 1
        reactor.stop()

class PdnsPipeProtocol(LineReceiver):
    """Line protocol speaking the PowerDNS pipe backend dialogue.

    Each input line is "COMMAND" or "COMMAND<TAB>rest". It is dispatched to
    a method named do_<state>_<COMMAND>, which returns a Deferred firing
    with the list of response lines. In state 'start' only HELO is
    accepted; "HELO<TAB>1" moves the protocol to state 'main', where Q,
    AXFR and PING are handled.

    LDAP lookups may complete in any order, so every question gets a slot
    in self.work in arrival order, and answers are written strictly in that
    order (see _doWork).
    """
    delimiter = '\n'
    state = 'start'
    # Set once connectionLost has queued the exit marker, so a repeated call
    # (e.g. stdin EOF followed by the LDAP connection dropping) is ignored.
    _exitQueued = False
    # Characters with special meaning in LDAP filter syntax. Question names
    # containing any of them are never interpolated into a search filter.
    _filterSpecials = '\\*()\0'

    def __init__(self, ldapObject, dnsDomain):
        """
        @param ldapObject: object whose search(filterText=..., attributes=...)
            performs the lookups (Search.bind_ok passes an LDAPEntry).
        @param dnsDomain: domain served for A queries, e.g. 'example.com'.
        """
        # Response slots in arrival order. A slot holds [Deferred] until its
        # result lines (or ExitSentinel) replace the contents in place.
        self.work = []
        self.ldapObject = ldapObject
        self.dnsDomain = dnsDomain

    def _doWork(self):
        """Write out every finished response at the head of the queue.

        Stops at the first slot that still holds a Deferred, so output order
        matches input order. Reaching the [ExitSentinel] slot stops the
        reactor.
        """
        while self.work:
            head = self.work[0]
            # A slot still holding its Deferred has not completed yet; an
            # empty slot (a handler that produced no lines) counts as done.
            if head and isinstance(head[0], defer.Deferred):
                break
            del self.work[0]
            if head == [ExitSentinel]:
                # connectionLost queues the sentinel exactly once, last.
                assert not self.work
                reactor.stop()
            else:
                for line in head:
                    self.sendLine(line)
                sys.stdout.flush()

    def completed(self, result, who):
        """Callback: store the response lines in slot *who* and flush."""
        who[:] = result
        self._doWork()

    def failed(self, result, who):
        """Errback: fill slot *who* with the traceback as LOG lines, then FAIL."""
        buf = StringIO()
        result.printTraceback(file=buf)
        who[:] = (['LOG\t%s' % line for line in buf.getvalue().splitlines()]
                  + ['FAIL'])
        self._doWork()

    def do_start_HELO(self, rest):
        """Accept only protocol version '1'; anything else gets FAIL."""
        if rest == '1':
            self.state = 'main'
            return succeed(['OK\t%s' % sys.argv[0]])
        return succeed(['FAIL'])

    def _stripSuffix(self, qname, suffix):
        """Return *qname* without *suffix*, or None if it does not end with it.

        DNS names are case-insensitive, so the comparison ignores ASCII case;
        the returned prefix keeps the original spelling.
        """
        if qname.lower().endswith(suffix.lower()):
            return qname[:len(qname) - len(suffix)]
        return None

    def _isFilterSafe(self, value):
        """Return True if *value* contains no LDAP filter metacharacters."""
        for ch in value:
            if ch in self._filterSpecials:
                return False
        return True

    def _gotA(self, results, qname, qclass, qtype, ident):
        """Format one DATA line per ipHostNumber of each entry, then END."""
        r = []
        for o in results:
            for ip in o.get('ipHostNumber', ()):
                r.append('\t'.join(('DATA', qname, qclass, qtype,
                                    '3600', ident, ip)))
        r.append('END')
        return r

    def question_A(self, qname, qclass, ident, ipAddress):
        """Answer an A query for <host>.<dnsDomain> from entries with cn=<host>."""
        # Answers always carry '-1' in the id column, whatever the query said.
        ident = '-1'
        cn = self._stripSuffix(qname, '.' + self.dnsDomain)
        # Outside our domain, an empty host label, or a name that would be
        # read as filter syntax (e.g. '*' matching every host): no answer.
        if not cn or not self._isFilterSafe(cn):
            return succeed(['END'])
        d = self.ldapObject.search(
            filterText='(&(cn=%s)(ipHostNumber=*))' % cn,
            attributes=['ipHostNumber'])
        d.addCallback(self._gotA, qname, qclass, 'A', ident)
        return d

    def question_ANY(self, qname, qclass, ident, ipAddress):
        """Treat reverse-zone names as PTR questions, everything else as A."""
        if qname.lower().endswith('.in-addr.arpa'):
            return self.question_PTR(qname, qclass, ident, ipAddress)
        return self.question_A(qname, qclass, ident, ipAddress)

    def _gotPTR(self, results, qname, qclass, qtype, ident):
        """Format one DATA line per cn of each entry (as <cn>.<domain>.), then END."""
        r = []
        for o in results:
            for cn in o.get('cn', ()):
                r.append('\t'.join(('DATA', qname, qclass, qtype,
                                    '3600', ident,
                                    cn + '.' + self.dnsDomain + '.')))
        r.append('END')
        return r

    def question_PTR(self, qname, qclass, ident, ipAddress):
        """Answer a d.c.b.a.in-addr.arpa query from entries with ipHostNumber=a.b.c.d."""
        ident = '-1'
        reversedIp = self._stripSuffix(qname, '.in-addr.arpa')
        if reversedIp is None:
            return succeed(['END'])
        octets = reversedIp.split('.')
        if len(octets) != 4:
            return succeed(['END'])
        # Only plain decimal octets; this also keeps '*', '(' and the like
        # out of the filter below.
        for octet in octets:
            if not octet.isdigit():
                return succeed(['END'])
        octets.reverse()
        ip = '.'.join(octets)
        d = self.ldapObject.search(filterText='(ipHostNumber=%s)' % ip,
                                   attributes=['cn'])
        d.addCallback(self._gotPTR, qname, qclass, 'PTR', ident)
        return d

    def do_main_Q(self, rest):
        """Handle "Q<TAB>qname<TAB>qclass<TAB>qtype<TAB>id<TAB>remote-ip"."""
        try:
            qname, qclass, qtype, ident, ipAddress = rest.split('\t', 4)
        except ValueError:
            return succeed(['LOG\tInvalid question: %s' % repr(rest),
                            'END'])
        if qclass != 'IN':
            return succeed(['LOG\tInvalid qclass: %s' % repr(qclass),
                            'END'])
        handler = getattr(self, 'question_' + qtype, None)
        if handler:
            return handler(qname, qclass, ident, ipAddress)
        # Unsupported query type: empty answer.
        return succeed(['END'])

    def do_main_AXFR(self, rest):
        """Zone transfers are not supported."""
        return succeed(['LOG\tRefusing AXFR', 'END'])

    def do_main_PING(self, rest):
        #TODO it's undocumented what I should be saying
        return succeed(['END'])

    def lineReceived(self, line):
        """Dispatch one command line and queue a slot for its answer."""
        try:
            try:
                command, rest = line.split('\t', 1)
            except ValueError:
                command, rest = line, ''
            handler = getattr(self, 'do_' + self.state + '_' + command, None)
            if handler:
                d = handler(rest)
            else:
                d = succeed(['LOG\tUnknown command %s in state %s'
                             % (repr(command), self.state),
                             'END'])
        except Exception:
            # Any error while dispatching becomes a FAIL answer via failed().
            d = fail(Failure())

        # Append the slot before attaching callbacks: an already-fired
        # Deferred runs completed()/failed() immediately, and _doWork
        # expects the slot to be in the queue by then.
        slot = [d]
        self.work.append(slot)
        d.addCallback(self.completed, slot)
        d.addErrback(self.failed, slot)

    def connectionLost(self, reason=None):
        """Queue the exit marker; the reactor stops once earlier answers are out.

        *reason* is optional because Search.connectionLost calls this with
        no argument, while Twisted's stdio transport may pass one. Only the
        first call has any effect.
        """
        if self._exitQueued:
            return
        self._exitQueued = True
        self.work.append([ExitSentinel])
        self._doWork()

# Process exit code: main() passes this value to sys.exit() after
# reactor.run() returns.
exitStatus = 0

def main(base, serviceLocationOverride, dnsDomain):
    """Connect to the LDAP server for *base*, run the reactor, then exit.

    @param base: LDAP search base as a DN string.
    @param serviceLocationOverride: passed to LDAPConnector as its
        ``overrides`` argument.
    @param dnsDomain: DNS domain answered for A queries.

    Does not return: calls sys.exit(exitStatus) once the reactor stops.
    """
    baseDN = distinguishedname.DistinguishedName(stringValue=base)
    factory = SearchFactory(base, dnsDomain)
    connector = ldapconnector.LDAPConnector(
        reactor, baseDN, factory, overrides=serviceLocationOverride)
    connector.connect()
    reactor.run()
    sys.exit(exitStatus)

class MyOptions(usage.Options,
                usage.Options_service_location,
                usage.Options_base):
    """LDAPtor PDNS pipe backend"""

    # Only the DNS domain is declared here; the 'base' and
    # 'service-location' options read in the __main__ section presumably
    # come from the usage.Options_* mixins.
    optParameters = (
        ('dns-domain', None, 'example.com', "DNS domain name"),
    )

if __name__ == "__main__":
    try:
        config = MyOptions()
        config.parseOptions()
    except usage.UsageError:
        # Fetched via sys.exc_info() so the handler parses on any Python.
        err = sys.exc_info()[1]
        sys.stderr.write('%s: %s\n' % (sys.argv[0], err))
        sys.exit(1)

    # stdout carries the pipe protocol's answers, so logging goes to stderr
    # and setStdout=0 leaves sys.stdout alone.
    from twisted.python import log
    log.startLogging(sys.stderr, setStdout=0)

    opts = config.opts
    main(opts['base'], opts['service-location'], opts['dns-domain'])
