#! /usr/bin/env python
# Copyright (C) 2005 Red Hat 
# see file 'COPYING' for use and warranty information
#
#    chcat is a script that allows you to modify the security label on a file
#
#    Author: Daniel Walsh <dwalsh@redhat.com>
#
#    This program is free software; you can redistribute it and/or
#    modify it under the terms of the GNU General Public License as
#    published by the Free Software Foundation; either version 2 of
#    the License, or (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA     
#                                        02111-1307  USA
#
#  
import commands, sys, os, pwd, string, getopt, re, selinux
import seobject

def verify_users(users):
    """Abort via error() unless every name in *users* is a local account.

    Names are checked in order with pwd.getpwnam(); the first unknown name
    triggers error(), which prints a message and exits.
    """
    for name in users:
        try:
            pwd.getpwnam(name)
        except KeyError:
            error("User %s does not exist" % name)

def chcat_user_add(orig, newcat, users):
    """Add categories to the SELinux login mapping of each user.

    orig   -- the category text as typed; not used here, kept so the
              signature matches the other chcat_* helpers
    newcat -- translated list [sensitivity, cat, ...]; every category in
              newcat[1:] is added to the high end of the user's range
    users  -- Linux login names; each must exist (checked by verify_users)

    A user without a mapping of their own gets a new one built from the
    "__default__" entry.  Always returns 0.
    """
    errors = 0
    logins = seobject.loginRecords()
    seusers = logins.get_all()
    verify_users(users)
    for u in users:
        # Decide add-vs-modify per user.  The flag used to be set once before
        # the loop, so after the first unmapped user every later user that
        # already had a mapping was wrongly passed to logins.add().
        if u in seusers:
            user = seusers[u]
            add_ind = 0
        else:
            user = seusers["__default__"]
            add_ind = 1

        # user[1] is parsed as "low-high[:cats]": the low end is copied
        # verbatim, categories are read from and written to the high end.
        serange = user[1].split("-")
        cats = []
        top = ["s0"]
        if len(serange) > 1:
            top = serange[1].split(":")
            if len(top) > 1:
                cats = expandCats([top[1]])

        for cat in newcat[1:]:
            if cat not in cats:
                cats.append(cat)
        new_range = "%s-%s:%s" % (serange[0], top[0], ",".join(cats))

        if add_ind:
            logins.add(u, user[0], new_range)
        else:
            logins.modify(u, user[0], new_range)
    return errors
        
def chcat_add(orig, newcat, objects, login_ind):
    """Add categories to the MLS level of each object.

    orig      -- the category as the user typed it (used in messages)
    newcat    -- translated list [sensitivity, cat, ...]; every category in
                 newcat[1:] is added
    objects   -- file paths, or login names when login_ind == 1
    login_ind -- 1 to edit SELinux login mappings instead of file labels

    Returns the number of objects that could not be relabelled.
    Raises ValueError when newcat carries no category.
    """
    import subprocess

    if len(newcat) == 1:
        raise ValueError("Requires at least one category")

    if login_ind == 1:
        return chcat_user_add(orig, newcat, objects)

    errors = 0
    sensitivity = newcat[0]
    for f in objects:
        try:
            context = selinux.getfilecon(f)[1]
        except OSError:
            print("%s: %s" % (f, sys.exc_info()[1]))
            errors += 1
            continue
        # Fields 3+ of user:role:type:level hold the MLS level; translate()
        # turns them into [sensitivity, cat, ...].
        clist = translate(context.split(":")[3:])
        if sensitivity != clist[0]:
            print("Can not modify sensitivity levels using '+' on %s" % f)
            # Skip the file: relabelling it would replace its sensitivity,
            # which is exactly what the message says we refuse to do.
            continue

        cats = clist[1:]
        added = 0
        for cat in newcat[1:]:
            if cat not in cats:
                cats.append(cat)
                added = 1
        if not added:
            print("%s is already in %s" % (f, orig))
            continue
        cats.sort()

        level = "%s:%s" % (sensitivity, ",".join(cats))
        # Argument list instead of a shell string, so file names containing
        # spaces or shell metacharacters work; "--" keeps names that start
        # with "-" from being parsed as chcon options.
        try:
            proc = subprocess.Popen(["chcon", "-l", level, "--", f],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    universal_newlines=True)
            output = proc.communicate()[0]
            status = proc.returncode
        except OSError:
            output = str(sys.exc_info()[1])
            status = 1
        if status != 0:
            print(output.rstrip("\n"))
            errors += 1
    return errors

def chcat_user_remove(orig, newcat, users):
    """Remove categories from the SELinux login mapping of each user.

    orig   -- the category text as typed; not used here, kept so the
              signature matches the other chcat_* helpers
    newcat -- translated list [sensitivity, cat, ...]; every category in
              newcat[1:] is removed from the high end of the user's range
    users  -- Linux login names; each must exist (checked by verify_users)

    A user without a mapping of their own gets a new one built from the
    "__default__" entry.  Always returns 0.
    """
    errors = 0
    logins = seobject.loginRecords()
    seusers = logins.get_all()
    verify_users(users)
    for u in users:
        # Decide add-vs-modify per user (the flag used to leak from one
        # unmapped user to every later user in the list).
        if u in seusers:
            user = seusers[u]
            add_ind = 0
        else:
            user = seusers["__default__"]
            add_ind = 1

        # user[1] is parsed as "low-high[:cats]": the low end is copied
        # verbatim, categories live on the high end.
        serange = user[1].split("-")
        cats = []
        top = ["s0"]
        if len(serange) > 1:
            top = serange[1].split(":")
            if len(top) > 1:
                cats = expandCats([top[1]])

        for cat in newcat[1:]:
            if cat in cats:
                cats.remove(cat)

        # With every category removed, emit "low-high" rather than the
        # malformed "low-high:" with a dangling colon.
        if cats:
            new_range = "%s-%s:%s" % (serange[0], top[0], ",".join(cats))
        else:
            new_range = "%s-%s" % (serange[0], top[0])

        if add_ind:
            logins.add(u, user[0], new_range)
        else:
            logins.modify(u, user[0], new_range)
    return errors
        
def chcat_remove(orig, newcat, objects, login_ind):
    """Remove categories from the MLS level of each object.

    orig      -- the category as the user typed it (used in messages)
    newcat    -- translated list [sensitivity, cat, ...]; every category in
                 newcat[1:] is removed
    objects   -- file paths, or login names when login_ind == 1
    login_ind -- 1 to edit SELinux login mappings instead of file labels

    Returns the number of objects that could not be relabelled.
    Raises ValueError when newcat carries no category.
    """
    import subprocess

    if len(newcat) == 1:
        raise ValueError("Requires at least one category")

    if login_ind == 1:
        return chcat_user_remove(orig, newcat, objects)

    errors = 0
    sensitivity = newcat[0]
    # Held apart from any per-file state: the old code reused the name of the
    # category being removed to build each file's new level, so every file
    # after the first was checked against the wrong category.
    removals = newcat[1:]

    for f in objects:
        try:
            context = selinux.getfilecon(f)[1]
        except OSError:
            print("%s: %s" % (f, sys.exc_info()[1]))
            errors += 1
            continue
        # Fields 3+ of user:role:type:level hold the MLS level; translate()
        # turns them into [sensitivity, cat, ...].
        clist = translate(context.split(":")[3:])
        if sensitivity != clist[0]:
            print("Can not modify sensitivity levels using '-' on %s" % f)
            continue

        current = clist[1:]
        remaining = [c for c in current if c not in removals]
        if len(remaining) == len(current):
            print("%s is not in %s" % (f, orig))
            continue

        if remaining:
            level = "%s:%s" % (sensitivity, ",".join(remaining))
        else:
            level = sensitivity

        # Argument list instead of a shell string, so odd file names work;
        # "--" stops names starting with "-" being read as chcon options.
        try:
            proc = subprocess.Popen(["chcon", "-l", level, "--", f],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    universal_newlines=True)
            output = proc.communicate()[0]
            status = proc.returncode
        except OSError:
            output = str(sys.exc_info()[1])
            status = 1
        if status != 0:
            print(output.rstrip("\n"))
            errors += 1
    return errors

def chcat_user_replace(orig, newcat, users):
    """Replace the high end of each user's login-mapping range with newcat.

    orig   -- not used here; kept so the signature matches the other
              chcat_* helpers
    newcat -- translated list [sensitivity, cat, ...]; becomes the new high
              end of the range (the low end is kept)
    users  -- Linux login names; each must exist (checked by verify_users)

    A user without a mapping of their own gets a new one built from the
    "__default__" entry.  Always returns 0.
    """
    errors = 0
    logins = seobject.loginRecords()
    seusers = logins.get_all()
    verify_users(users)
    for u in users:
        # Decide add-vs-modify per user (the flag used to leak from one
        # unmapped user to every later user in the list).
        if u in seusers:
            user = seusers[u]
            add_ind = 0
        else:
            user = seusers["__default__"]
            add_ind = 1

        serange = user[1].split("-")
        new_range = "%s-%s" % (serange[0], newcat[0])
        # Only append ":cats" when there are categories; "chcat -l -d"
        # passes ["s0"] and used to produce the malformed "s0-s0:".
        if len(newcat) > 1:
            new_range = "%s:%s" % (new_range, ",".join(newcat[1:]))

        if add_ind:
            logins.add(u, user[0], new_range)
        else:
            logins.modify(u, user[0], new_range)
    return errors
    
def chcat_replace(orig, newcat, objects, login_ind):
    """Set the MLS level of every object to exactly newcat.

    orig      -- not used for files; forwarded to chcat_user_replace
    newcat    -- translated list [sensitivity, cat, ...]
    objects   -- file paths, or login names when login_ind == 1
    login_ind -- 1 to edit SELinux login mappings instead of file labels

    All files are relabelled by a single chcon invocation, so the return
    value is 0 on success and 1 if that invocation failed.
    """
    import subprocess

    if login_ind == 1:
        return chcat_user_replace(orig, newcat, objects)

    if len(newcat) > 1:
        level = "%s:%s" % (newcat[0], ",".join(newcat[1:]))
    else:
        level = newcat[0]

    errors = 0
    # Argument list instead of a concatenated shell string, so file names
    # with spaces or shell metacharacters are passed through intact; "--"
    # stops names starting with "-" being read as chcon options.
    try:
        proc = subprocess.Popen(["chcon", "-l", level, "--"] + list(objects),
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                universal_newlines=True)
        output = proc.communicate()[0]
        status = proc.returncode
    except OSError:
        output = str(sys.exc_info()[1])
        status = 1
    if status != 0:
        print(output.rstrip("\n"))
        errors += 1

    return errors

def check_replace(cats):
    """Classify a list of category arguments.

    Entries starting with "+" or "-" are incremental edits; anything else
    (including the empty string) is a plain category.  The two kinds may
    not be mixed.

    Returns 1 if any plain entry is present, 0 otherwise (all edits, or an
    empty list).  Raises ValueError on a mix of the two kinds.
    """
    saw_edit = 0
    saw_plain = 0
    for entry in cats:
        is_edit = len(entry) > 0 and entry[0] in ("+", "-")
        if not is_edit:
            if saw_edit:
                raise ValueError("Can not combine +/- with other types of categories")
            saw_plain = 1
            continue
        if saw_plain:
            raise ValueError("Can not combine +/- with other types of categories")
        saw_edit = 1
    return saw_plain

def isSensitivity(sensitivity):
    """Return 1 if *sensitivity* looks like "s0".."s15", else 0.

    Accepts any string; the empty string (which used to raise IndexError)
    and strings such as "s", "sx" or "c3" return 0.
    """
    # startswith() instead of indexing [0] so "" is answered, not crashed on.
    if not sensitivity.startswith("s"):
        return 0
    digits = sensitivity[1:]
    # Plain comparison instead of "in range(0,16)", which built a list per
    # call under Python 2.
    if digits.isdigit() and 0 <= int(digits) < 16:
        return 1
    return 0
    
def expandCats(cats):
    """Expand category specifications into a flat, de-duplicated list.

    Each element of *cats* may be a single category ("c3"), a comma list
    ("c1,c3"), a dotted range ("c0.c5", expanded inclusively to "c0".."c5"),
    or any comma-separated mix of those ("c0,c3.c5").  Order of first
    appearance is kept and duplicates are dropped.
    """
    newcats = []
    seen = set()
    for spec in cats:
        # Split on commas first so mixed specs such as "c0,c3.c5" work; the
        # old code split on "." first and crashed in int() on them.
        for item in spec.split(","):
            if item.find(".") != -1:
                bounds = item.split(".")
                low = int(bounds[0][1:])
                high = int(bounds[1][1:])
                names = ["c%d" % i for i in range(low, high + 1)]
            else:
                names = [item]
            for name in names:
                if name not in seen:
                    seen.add(name)
                    newcats.append(name)
    return newcats

def translate(cats):
    """Turn level fragments into a raw [sensitivity, cat, cat, ...] list.

    Each element of *cats* is run through selinux_trans_to_raw_context()
    (wrapped in a dummy "a:b:c:" context) so translated names become raw
    levels; category ranges are expanded with expandCats().  An empty
    input yields ["s0"].

    A fragment with no explicit sensitivity inherits the sensitivity
    already collected from an earlier fragment (or "s0" if there is none).
    Raises ValueError if two fragments name different sensitivities.
    """
    newcat = []
    if len(cats) == 0:
        newcat.append("s0")
        return newcat
    for c in cats:
        (rc, raw) = selinux.selinux_trans_to_raw_context("a:b:c:%s" % c)
        rlist = raw.split(":")[3:]
        if isSensitivity(rlist[0]):
            sens = rlist[0]
            items = rlist[1:]
        else:
            # Inherit rather than force "s0": a level split on ":" such as
            # ["s1", "c1"] used to be reported as two sensitivities.
            if newcat:
                sens = newcat[0]
            else:
                sens = "s0"
            items = rlist

        if not newcat:
            newcat.append(sens)
        elif newcat[0] != sens:
            raise ValueError("Can not have multiple sensitivities")
        newcat.extend(expandCats(items))
    return newcat
    
def usage():
    """Print the command-line synopsis and exit with status 1."""
    # Indented with spaces only: the old body mixed tabs and spaces, which
    # Python 3 (and python2 -tt) reject.
    prog = sys.argv[0]
    print("Usage %s CATEGORY File ..." % prog)
    print("Usage %s -l CATEGORY user ..." % prog)
    print("Usage %s [+|-]CATEGORY[,...] File ..." % prog)
    print("Usage %s -l [+|-]CATEGORY[,...] user ..." % prog)
    print("Usage %s -d File ..." % prog)
    print("Usage %s -l -d user ..." % prog)
    print("Usage %s -L" % prog)
    print("Usage %s -L -l user" % prog)
    print("Use -- to end option list.  For example")
    print("chcat -- -CompanyConfidential /docs/businessplan.odt")
    print("chcat -l +CompanyConfidential juser")
    sys.exit(1)

def listcats():
    """Print every "key=value" line of the SELinux translations file.

    Lines starting with "#" and lines without "=" are skipped.  Each entry
    is printed as the left side padded to 30 columns, then the right side.
    Returns 0.
    """
    fd = open(selinux.selinux_translations_path())
    # try/finally so the file is closed even if reading fails.
    try:
        lines = fd.read().split("\n")
    finally:
        fd.close()
    for l in lines:
        if l.startswith("#"):
            continue
        if l.find("=") != -1:
            # Split on the first "=" only: a value containing "=" used to
            # make the two-slot format raise TypeError.
            rec = l.split("=", 1)
            print("%-30s %s" % (rec[0], rec[1]))
    return 0
    

def listusercats(users):
    """Print the MLS level of each user's SELinux login mapping.

    users -- login names; when empty, the invoking user is listed.  The
             caller's list is not modified.

    For each user prints "user: level", showing the high end of the
    translated range unless the range has none or it is "s0".
    """
    if len(users) == 0:
        # os.getlogin() can fail with OSError when there is no controlling
        # terminal; fall back to the account of the real uid.
        try:
            users = [os.getlogin()]
        except OSError:
            users = [pwd.getpwuid(os.getuid())[0]]

    verify_users(users)
    for u in users:
        cats = seobject.translate(selinux.getseuserbyname(u)[2]).split("-")
        if len(cats) > 1 and cats[1] != "s0":
            print("%s: %s" % (u, cats[1]))
        else:
            print("%s: %s" % (u, cats[0]))
            
def error(msg):
    """Report *msg* prefixed by the program name, then exit with status 1."""
    text = "%s: %s" % (sys.argv[0], msg)
    print(text)
    sys.exit(1)
    
if __name__ == '__main__':
    # Check for SELinux itself first, so a system with SELinux disabled gets
    # that message rather than the more specific MLS one.
    if selinux.is_selinux_enabled() != 1:
        error("Requires an SELinux enabled system")

    if selinux.is_selinux_mls_enabled() != 1:
        error("Requires a mls enabled system")

    delete_ind = 0
    list_ind = 0
    login_ind = 0
    try:
        gopts, cmds = getopt.getopt(sys.argv[1:],
                                    'dhlL',
                                    ['list',
                                     'login',
                                     'help',
                                     'delete'])
    except getopt.GetoptError:
        # getopt reports bad options with GetoptError, not ValueError (which
        # the old handler caught), so they used to end in a traceback.
        print("%s: %s" % (sys.argv[0], sys.exc_info()[1]))
        usage()

    for o, a in gopts:
        if o in ("-h", "--help"):
            usage()
        if o in ("-d", "--delete"):
            delete_ind = 1
        if o in ("-L", "--list"):
            list_ind = 1
        if o in ("-l", "--login"):
            login_ind = 1

    if list_ind == 0 and len(cmds) < 1:
        usage()

    if delete_ind:
        # Deleting categories is done by replacing the level with bare s0.
        sys.exit(chcat_replace(["s0"], ["s0"], cmds, login_ind))

    if list_ind:
        if login_ind:
            sys.exit(listusercats(cmds))
        if len(cmds) > 0:
            usage()
        sys.exit(listcats())

    if len(cmds) < 2:
        usage()

    cats = cmds[0].split(",")
    objects = cmds[1:]
    errors = 0
    try:
        if check_replace(cats):
            errors = chcat_replace(cats, translate(cats), objects, login_ind)
        else:
            # check_replace() returned 0, so every entry is "+cat" or "-cat".
            for c in cats:
                name = c[1:]
                if c[0] == "+":
                    errors += chcat_add(name, translate([name]), objects, login_ind)
                elif c[0] == "-":
                    errors += chcat_remove(name, translate([name]), objects, login_ind)
    except ValueError:
        error(sys.exc_info()[1])

    sys.exit(errors)
    


