#! /usr/bin/env python
# Copyright (C) 2005 Red Hat 
# see file 'COPYING' for use and warranty information
#
# Audit2allow is a rewrite of prior perl script.
#
# Based off original audit2allow perl script: which credits
#    newrules.pl, Copyright (C) 2001 Justin R. Smith (jsmith@mcs.drexel.edu)
#    2003 Oct 11: Add -l option by Yuichi Nakamura(ynakam@users.sourceforge.jp)
#
#    This program is free software; you can redistribute it and/or
#    modify it under the terms of the GNU General Public License as
#    published by the Free Software Foundation; either version 2 of
#    the License, or (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA     
#                                        02111-1307  USA
#
#  
import commands, sys, os, pwd, string, getopt, re, selinux

# Regular-expression fragment for one field of an allow rule: either a
# brace-delimited set such as "{ read write }" or a single token containing
# no blank, tab or colon.  Backslashes are written out explicitly so the
# literal does not depend on Python passing unknown escapes through.
obj="(\\{[^\\}]*\\}|[^ \t:]*)"
# Matches "allow <source> <target>:<class> <access>" and captures the four
# fields in that order.
allow_regexp="allow[ \t]+%s[ \t]*%s[ \t]*:[ \t]*%s[ \t]*%s" % (obj, obj, obj, obj)

# awk program (gensub() is a gawk extension) that is fed to "awk -f -".
#  * On a line opening "interface(" it remembers the current file and the
#    interface name (text after the opening quote, up to the first comma).
#  * On an "allow ...;" line in the file of the remembered interface it
#    prints FILENAME, IFACENAME and the rule, separated by tabs.
# The rule is printed without leading blanks and without the trailing ";":
# the second gensub() works on ALLOW (not on $0 again) so that the result of
# the first substitution is not thrown away.
awk_script = (
    '/^[[:blank:]]*interface[[:blank:]]*\\(/ {\n'
    '        IFACEFILE=FILENAME\n'
    '\tIFACENAME = gensub("^[[:blank:]]*interface[[:blank:]]*\\\\(\\`?","","g",$0);\n'
    '\tIFACENAME = gensub("\'?,.*$","","g",IFACENAME);\n'
    '}\n'
    '\n'
    '/^[[:blank:]]*allow[[:blank:]]+.*;[[:blank:]]*$/ {\n'
    '\n'
    '  if ((length(IFACENAME) > 0) && (IFACEFILE == FILENAME)){\n'
    '\t\tALLOW = gensub("^[[:blank:]]*","","g",$0)\n'
    '\t\tALLOW = gensub(";[[:blank:]]*$","","g",ALLOW)\n'
    '\t\tprint FILENAME "\\t" IFACENAME "\\t" ALLOW;\n'
    '\t}\n'
    '}'
)

class accessTrans:
    """Expand permission-set macro names into the permissions they define.

    self.dict maps each macro name found in the permission-set file (lines
    of the form "define(`name', `{ perm ... }')") to its list of permissions.
    """

    def __init__(self, path="/usr/share/selinux/devel/include/support/obj_perm_sets.spt"):
        """Load the macro definitions from path.

        path defaults to the location the script has always read; it is a
        parameter only so that another file can be supplied.

        Raises IOError (with a hint about the policy development package)
        when the file cannot be opened.
        """
        self.dict = {}
        try:
            fd = open(path)
        except IOError as error:
            raise IOError("Reference policy generation requires the policy development package.\n%s" % error)
        try:
            records = fd.read().split("\n")
        finally:
            # Release the handle even if reading fails.
            fd.close()
        regexp = re.compile(r"^define *\(`([^']*)' *, *` *\{([^}]*)}'")
        for r in records:
            m = regexp.match(r)
            if m is not None:
                self.dict[m.group(1)] = m.group(2).split()

    def get(self, var):
        """Return the words of var with macro names replaced by their permissions.

        Known macro names are replaced by the permissions they stand for, the
        brace tokens "{" and "}" are dropped, and every other word is kept
        unchanged.  Order is preserved.
        """
        l = []
        for v in var:
            if v in self.dict:
                l += self.dict[v]
            elif v not in ("{", "}"):
                l.append(v)
        return l

class interfaces:
    """Index of the allow rules found inside policy interface files.

    self.dict maps a (source, target, class) tuple, taken verbatim from an
    allow rule inside an interface definition, to a list of
    (access list, module name, interface name) tuples.
    """

    def __init__(self):
        """Run the awk scanner over the interface files and build the index."""
        # Function-scope import: the block brings its own dependency.
        import subprocess
        self.dict = {}
        trans = accessTrans()
        # The command is run through the shell, which expands the glob and
        # discards awk's stderr.
        awk = subprocess.Popen(
            "awk -f - /usr/share/selinux/devel/include/*/*.if 2> /dev/null",
            shell=True,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            universal_newlines=True)
        # communicate() writes the program to awk's stdin, reads all of its
        # stdout, closes both pipes and waits for the child to exit.
        records = awk.communicate(awk_script)[0].split("\n")
        # Each output line is "<file>\t<interface>\t<allow rule>".
        regexp = re.compile("([^ \t]*)[ \t]+([^ \t]*)[ \t]+%s" % allow_regexp)
        for r in records:
            m = regexp.match(r)
            if m is None:
                continue
            val = m.groups()
            # Module name = interface file name up to its first dot.
            module = os.path.basename(val[0]).split(".")[0]
            iface = val[1]
            access = trans.get(val[5].split())
            for s in val[2].split():
                for t in val[3].split():
                    for c in val[4].split():
                        self.dict.setdefault((s, t, c), []).append((access, module, iface))

    def out(self):
        """Print every key of the index followed by its entries (one per line)."""
        for k in sorted(self.dict.keys()):
            print(k)
            for i in self.dict[k]:
                sys.stdout.write("\t%s\n" % (i,))

    def _rank(self, entries, Access):
        """Return the entries granting Access, name matches first.

        An entry whose interface name contains Access is inserted at the
        front of the result; any other entry granting Access is appended.
        """
        ret = []
        for i in entries:
            if Access in i[0]:
                if Access in i[2]:
                    ret.insert(0, i)
                else:
                    ret.append(i)
        return ret

    def match(self, Scon, Tcon, Class, Access):
        """Return the interface entries that grant Access for the given rule.

        Three keys are tried in order: the exact (Scon, Tcon, Class), then
        the source replaced by "$1", then the target replaced by "$1".  Only
        the first key present in the index is consulted, even when none of
        its entries grants Access.  Returns a (possibly empty) list of
        (access list, module name, interface name) tuples.
        """
        candidates = (
            (Scon, Tcon, Class),
            ("$1", Tcon, Class),
            (Scon, "$1", Class),
        )
        for key in candidates:
            if key in self.dict:
                return self._rank(self.dict[key], Access)
        return []
        

class serule:
    """A single policy rule for one (type, source, target, class) combination.

    self.avcinfo maps each access added to the rule to the list of detail
    tuples (everything after the access list) it was added with.
    """

    def __init__(self, type, source, target, seclass):
        """Remember the rule keyword, source, target and object class."""
        self.iface = None
        self.avcinfo = {}
        self.seclass = seclass
        self.target = target
        self.source = source
        self.type = type

    def add(self, avc):
        """Record avc[1:] under every access listed in avc[0]."""
        for perm in avc[0]:
            self.avcinfo.setdefault(perm, []).append(avc[1:])

    def getAccess(self):
        """Return the access part of the rule.

        A lone access is returned as is; otherwise the accesses are sorted
        and wrapped in braces.
        """
        perms = list(self.avcinfo.keys())
        if len(perms) == 1:
            return perms[0]
        perms.sort()
        return "{" + "".join([" " + perm for perm in perms]) + " }"

    def out(self, verbose=0):
        """Return the rule as text; with verbose, append one comment per record."""
        pieces = ["%s %s %s:%s %s;" % (self.type, self.source, self.gettarget(), self.seclass, self.getAccess())]
        if verbose:
            for perm in sorted(self.avcinfo.keys()):
                for detail in self.avcinfo[perm]:
                    pieces.append("\n\t#TYPE=AVC  MSG=%s  " % detail[0])
                    if len(detail[1]):
                        pieces.append("COMM=%s  " % detail[1])
                    if len(detail[2]):
                        pieces.append("NAME=%s  " % detail[2])
                    pieces.append(" : " + perm)
        return "".join(pieces)

    def gen_reference_policy(self, iface):
        """Return the rule expressed through interfaces found by iface.match().

        Without any match the plain rule text is returned.  Otherwise the
        rule is emitted as a comment followed by the matching interface
        calls grouped by module; within a group only the first call is
        active and the remaining ones are commented out.
        """
        Scon = self.source
        found = iface.match(Scon, self.gettarget(), self.seclass, self.getAccess())
        if len(found) == 0:
            return self.out()

        current = found[0][1]
        pieces = ["\n#%s\n" % self.out(), "optional_policy(`%s', `\n" % current]
        active = True
        for entry in found:
            if entry[1] != current:
                # A new module starts a new group.
                current = entry[1]
                pieces.append("')\ngen_require(`%s', `\n" % current)
                active = True
            if active:
                template = "\t%s(%s)\n"
                active = False
            else:
                template = "#\t%s(%s)\n"
            pieces.append(template % (entry[2], Scon))
        pieces.append("');")
        return "".join(pieces)

    def gettarget(self):
        """Return "self" when source and target are equal, else the target."""
        if self.source != self.target:
            return self.target
        return "self"
	
class seruleRecords:
    """Collection of policy rules built from AVC messages or from a .te file.

    Attributes:
      seRules    -- maps (rule type, source, target, class) to a serule
      seclasses  -- maps each object class to the list of accesses seen
      types      -- list of all source and target types seen
      roles      -- list of all roles seen in AVC contexts
    """

    def __init__(self, input, last_reload=0, verbose=0, te_ind=0):
        """Create the collection and load input (a file-like object).

        last_reload -- when true, drop everything seen before a record that
                       contains the "load_policy" token
        verbose     -- passed to each rule when the plain output is generated
        te_ind      -- when true, input is parsed as type enforcement rules
                       instead of AVC messages
        """
        self.last_reload = last_reload
        self.verbose = verbose
        self.gen_ref_policy = False
        self.seRules = {}
        self.seclasses = {}
        self.types = []
        self.roles = []
        self.load(input, te_ind)

    def gen_reference_policy(self):
        """Switch the output to interface calls and build the interface index."""
        self.gen_ref_policy = True
        self.iface = interfaces()

    def warning(self, error):
        """Write "<program name>: <error>" to stderr."""
        sys.stderr.write("%s: " % sys.argv[0])
        sys.stderr.write("%s\n" % error)
        sys.stderr.flush()

    def load(self, input, te_ind=0):
        """Read input line by line until EOF and add every rule found.

        With te_ind, each line whose first word is a known rule keyword is
        handed to add_terule().  Otherwise the words of a line are collected
        and, once a line contains an AVC marker, passed to add().
        """
        VALID_CMDS = ("allow", "dontaudit", "auditallow", "role")
        AVC_MARKERS = ("avc:", "message=avc:", "msg='avc:")

        # Words of lines without a marker are NOT discarded: they stay in
        # avc and are handed over together with the next AVC line.
        avc = []
        while True:
            line = input.readline()
            if not line:
                break
            rec = line.split()
            if te_ind:
                if len(rec) and rec[0] in VALID_CMDS:
                    self.add_terule(line)
                continue
            found = False
            for word in rec:
                if word in AVC_MARKERS:
                    found = True
                else:
                    avc.append(word)
            if found:
                self.add(avc)
                avc = []

    def get_target(self, i, rule):
        """Parse one field of a rule starting at word index i.

        rule is a list of words.  A field is either a single word (a trailing
        ";" is stripped) or a brace-delimited set that may span several words;
        braces may be glued to the neighbouring words ("{read", "write};" or
        even "{read}").  Text after the closing brace is ignored.

        Returns (index of the word following the field, list of names).
        Raises IndexError when i is past the end of rule.
        """
        target = []
        if rule[i][0] != "{":
            target.extend([name for name in rule[i].split(";") if name])
            return (i + 1, target)

        idx = i
        while idx < len(rule):
            token = rule[idx]
            if idx == i:
                # Drop the opening brace; what follows is handled like any
                # other word so that "{read}" closes the set immediately.
                token = token[1:]
            if "}" in token:
                head = token.split("}")[0]
                if head:
                    target.append(head)
                return (idx + 1, target)
            if token:
                target.append(token)
            idx += 1
        # Unterminated set: same index the historical code returned.
        return (idx + 1, target)

    def rules_split(self, rules):
        """Return the first two fields of rules as a pair of name lists."""
        (idx, target) = self.get_target(0, rules)
        (idx, subject) = self.get_target(idx, rules)
        return (target, subject)

    def add_terule(self, rule):
        """Add the rules described by one type enforcement line.

        The line must look like "<keyword> <sources> <targets>:<classes>
        <accesses>;".  Lines that do not (for instance role statements, which
        contain no ":") are reported with warning() and skipped.
        """
        rc = rule.split(":")
        if len(rc) < 2:
            self.warning("Unsupported rule: %s" % rule.strip())
            return
        try:
            words = rc[0].split()
            rule_type = words[0]
            (sources, targets) = self.rules_split(words[1:])
            (seclasses, access) = self.rules_split(rc[1].split())
        except IndexError:
            self.warning("Bad rule: %s" % rule.strip())
            return
        for scon in sources:
            for tcon in targets:
                for seclass in seclasses:
                    self.add_rule(rule_type, scon, tcon, seclass, access)

    def add_rule(self, rule_type, scon, tcon, seclass, access, msg="", comm="", name=""):
        """Record access for (rule_type, scon, tcon, seclass), creating the rule if needed."""
        self.add_seclass(seclass, access)
        self.add_type(tcon)
        self.add_type(scon)
        key = (rule_type, scon, tcon, seclass)
        if key not in self.seRules:
            self.seRules[key] = serule(rule_type, scon, tcon, seclass)
        self.seRules[key].add((access, msg, comm, name))

    def add(self, avc):
        """Turn the words of one AVC record into an allow rule.

        Records containing "security_compute_sid" or "granted" are ignored.
        Records lacking a source type, target type or class are ignored too.
        A context with fewer than three ":"-separated parts is reported with
        warning() and the record is dropped.
        """
        if "security_compute_sid" in avc:
            return

        if "load_policy" in avc and self.last_reload:
            # Forget everything seen before the reload, including the names
            # that would otherwise still show up in the require block.
            self.seRules = {}
            self.seclasses = {}
            self.types = []
            self.roles = []

        if "granted" in avc:
            return

        scon = tcon = seclass = ""
        srole = trole = ""
        comm = name = msg = ""
        access = []
        try:
            i = 0
            while i < len(avc):
                token = avc[i]
                i += 1
                if token == "{":
                    # Everything up to the closing brace is an access.
                    while i < len(avc) and avc[i] != "}":
                        access.append(avc[i])
                        i += 1
                    i += 1
                    continue

                # Split on the first "=" only so values may contain "=".
                t = token.split("=", 1)
                if len(t) < 2:
                    continue
                (key, value) = t
                if key == "scontext":
                    context = value.split(":")
                    scon = context[2]
                    srole = context[1]
                elif key == "tcontext":
                    context = value.split(":")
                    tcon = context[2]
                    trole = context[1]
                elif key == "tclass":
                    seclass = value
                elif key == "comm":
                    comm = value
                elif key == "name":
                    name = value
                elif key == "msg":
                    msg = value
        except IndexError:
            self.warning("Bad AVC Line: %s" % avc)
            return

        if scon == "" or tcon == "" or seclass == "":
            return

        self.add_role(srole)
        self.add_role(trole)
        self.add_rule("allow", scon, tcon, seclass, access, msg, comm, name)

    def add_seclass(self, seclass, access):
        """Remember that every entry of access was seen for seclass."""
        known = self.seclasses.setdefault(seclass, [])
        for a in access:
            if a not in known:
                known.append(a)

    def add_role(self, role):
        """Remember role (once)."""
        if role not in self.roles:
            self.roles.append(role)

    def add_type(self, type):
        """Remember type (once)."""
        if type not in self.types:
            self.types.append(type)

    def gen_module(self, module):
        """Return the module declaration line for module."""
        return "module %s 1.0;" % module

    def gen_requires(self):
        """Return the require block listing all roles, classes and types seen.

        Sorts self.roles and self.types in place.  A class for which no
        access was recorded is left out, since there is nothing to list.
        """
        self.roles.sort()
        self.types.sort()
        rec = "\n\nrequire {\n"
        if len(self.roles) > 0:
            for i in self.roles:
                rec += "\trole %s; \n" % i
            rec += "\n"

        for i in sorted(self.seclasses.keys()):
            access = sorted(self.seclasses[i])
            if len(access) > 1:
                rec += "\tclass %s {" % i
                for a in access:
                    rec += " %s" % a
                rec += " }; \n"
            elif access:
                rec += "\tclass %s %s;\n" % (i, access[0])

        rec += "\n"

        for i in self.types:
            rec += "\ttype %s; \n" % i
        rec += " };\n\n\n"
        return rec

    def out(self, require=0, module=""):
        """Return the generated policy text.

        module  -- when not empty, a module declaration and the require
                   block are emitted first
        require -- when true (and no module is given), the require block is
                   emitted first

        Rules follow in sorted key order, one per line, either as interface
        calls (after gen_reference_policy()) or as plain rules rendered with
        the verbosity given to the constructor.

        Raises ValueError when no rule has been collected.
        """
        if len(self.seRules) == 0:
            raise ValueError("No AVC messages found.")
        rec = ""
        if module != "":
            rec += self.gen_module(module)
            rec += self.gen_requires()
        elif require:
            rec += self.gen_requires()

        for i in sorted(self.seRules.keys()):
            if self.gen_ref_policy:
                rec += self.seRules[i].gen_reference_policy(self.iface) + "\n"
            else:
                rec += self.seRules[i].out(self.verbose) + "\n"
        return rec

if __name__ == '__main__':

    def get_mls_flag():
        """Return "-M" when selinux reports MLS as enabled, "" otherwise."""
        if selinux.is_selinux_mls_enabled():
            return "-M"
        return ""

    def usage(msg=""):
        """Print the option summary, then msg if given, and exit with status 1."""
        print('audit2allow [-adhilrv] [-t file ] [ -f fcfile ] [-i <inputfile> ] [[-m|-M] <modulename> ] [-o <outputfile>]\n'
              '\t\t-a, --all        read input from audit and message log, conflicts with -i\n'
              '\t\t-d, --dmesg      read input from output of /bin/dmesg\n'
              '\t\t-h, --help       display this message\n'
              '\t\t-i, --input      read input from <inputfile> conflicts with -a\n'
              '\t\t-l, --lastreload read input only after last "load_policy"\n'
              '\t\t-m, --module     generate module/require output <modulename> \n'
              '\t\t-M               generate loadable module package, conflicts with -o\n'
              '\t\t-o, --output     append output to <outputfile>, conflicts with -M\n'
              '\t\t-r, --requires   generate require output \n'
              '\t\t-t, --tefile     Indicates input is Existing Type Enforcement file\n'
              '\t\t-f, --fcfile     Existing Type Enforcement file, requires -M\n'
              '\t\t-v, --verbose    verbose output\n'
              '\t\t')
        if msg != "":
            print(msg)
        sys.exit(1)

    def errorExit(error):
        """Write "<program name>: <error>" to stderr and exit with status 1."""
        sys.stderr.write("%s: " % sys.argv[0])
        sys.stderr.write("%s\n" % error)
        sys.stderr.flush()
        sys.exit(1)

    def bad_argument(value):
        """Return True for an option argument that is empty or looks like an option."""
        return value == "" or value[0] == "-"

    #
    # Parse the command line, collect the rules and write the result.
    #
    try:
        last_reload = 0
        input = sys.stdin
        output = sys.stdout
        module = ""
        requires = 0
        verbose = 0
        auditlogs = 0
        buildPP = 0
        input_ind = 0
        output_ind = 0
        ref_ind = False
        te_ind = 0

        fc_file = ""
        gopts, cmds = getopt.getopt(sys.argv[1:],
                                    'adf:hi:lm:M:o:rtvR',
                                    ['all',
                                     'dmesg',
                                     'fcfile=',
                                     'help',
                                     'input=',
                                     'lastreload',
                                     'module=',
                                     'output=',
                                     'requires',
                                     'reference',
                                     'tefile',
                                     'verbose'
                                     ])
        # Options are handled in command-line order; the conflict checks
        # below only see the options that came earlier.
        for o, a in gopts:
            if o == "-a" or o == "--all":
                if input_ind or te_ind:
                    usage()
                input = open("/var/log/messages", "r")
                auditlogs = 1
            elif o == "-d" or o == "--dmesg":
                input = os.popen("/bin/dmesg", "r")
            elif o == "-f" or o == "--fcfile":
                if bad_argument(a):
                    usage()
                fc_file = a
            elif o == "-h" or o == "--help":
                usage()
            elif o == "-i" or o == "--input":
                if auditlogs or bad_argument(a):
                    usage()
                input_ind = 1
                input = open(a, "r")
            elif o == "-l" or o == "--lastreload":
                last_reload = 1
            elif o == "-m" or o == "--module":
                if module != "" or bad_argument(a):
                    usage()
                module = a
            elif o == "-M":
                if module != "" or output_ind or bad_argument(a):
                    usage()
                module = a
                outfile = a + ".te"
                buildPP = 1
                output = open(outfile, "w")
            elif o == "-r" or o == "--requires":
                requires = 1
            elif o == "-t" or o == "--tefile":
                if auditlogs:
                    usage()
                te_ind = 1
            elif o == "-R" or o == "--reference":
                ref_ind = True
            elif o == "-o" or o == "--output":
                if module != "" or bad_argument(a):
                    usage()
                output = open(a, "a")
                output_ind = 1
            elif o == "-v" or o == "--verbose":
                verbose = 1

        if len(cmds) != 0:
            usage()

        if fc_file != "" and not buildPP:
            usage("Error %s: Option -fc requires -M" % sys.argv[0])

        out = seruleRecords(input, last_reload, verbose, te_ind)

        if ref_ind:
            out.gen_reference_policy()

        if auditlogs:
            # With -a the audit log is read in addition to /var/log/messages.
            input = os.popen("ausearch -m avc")
            out.load(input)

        if buildPP:
            print("Generating type enforcment file: %s.te" % module)
        output.write(out.out(requires, module))
        output.flush()
        if buildPP:
            # Compile the .te file, then package the module.
            cmd = "checkmodule %s -m -o %s.mod %s.te" % (get_mls_flag(), module, module)
            print("Compiling policy")
            print(cmd)
            rc = commands.getstatusoutput(cmd)
            if rc[0] != 0:
                errorExit(rc[1])
            cmd = "semodule_package -o %s.pp -m %s.mod" % (module, module)
            if fc_file != "":
                cmd = "%s -f %s" % (cmd, fc_file)

            print(cmd)
            rc = commands.getstatusoutput(cmd)
            if rc[0] != 0:
                errorExit(rc[1])
            print("\n******************** IMPORTANT ***********************\n")
            print("In order to load this newly created policy package into the kernel,\nyou are required to execute \n\nsemodule -i %s.pp\n\n" % module)

    except getopt.error as error:
        errorExit("Options Error " + error.msg)
    except ValueError as error:
        errorExit(error.args[0])
    except IOError as error:
        errorExit(error)
    except KeyboardInterrupt:
        sys.exit(0)
