#!/usr/bin/env python3

import subprocess
import os
import time
import sys

from argparse import ArgumentParser
from signal import SIGKILL

# Default values for the command-line options defined in main().
DEFAULT_SAMPLES = 3      # memory readings averaged per measurement (--samples)
DEFAULT_PROCESSES = 100  # copies of the X program to launch (--processes)
DEFAULT_SLEEP = 30       # seconds to wait after starting and after stopping (--sleep)
DEFAULT_MARGIN = 0.1     # allowed relative deviation before/after, i.e. 10% (--margin)


def mem_usage(process, samples=1):
    """Return the average "mapped" memory of *process* as reported by pmap.

    For each sample the PID(s) are looked up with ``pidof``. The last line of
    ``pmap -d`` output is then parsed; that summary line looks like::

        mapped: 123456K    writeable/private: 7890K    shared: 1234K

    Samples are taken 3 seconds apart and the "mapped" values (in K) are
    averaged.

    NOTE(review): when pidof reports several PIDs, all of them are passed to
    pmap and only its final summary line is used (presumably the last-listed
    process) -- confirm this is acceptable for the target process.

    Args:
        process: process name given to ``pidof`` (e.g. ``"X"``).
        samples: number of readings to average; must be at least 1.

    Returns:
        The mean "mapped" value as a float.

    Raises:
        ValueError: if *samples* is less than 1.
        Exception: if no PID is found for *process* or the pmap summary
            line cannot be parsed.
    """
    if samples < 1:
        raise ValueError("samples must be at least 1, got %r" % (samples,))

    usage_sum = 0
    for n in range(1, samples + 1):
        # pidof prints space-separated PIDs, or nothing when the process is
        # absent. getoutput() also captures shell errors (e.g. pidof not
        # installed), so only accept output made entirely of numeric PIDs;
        # anything else would be handed to pmap as bogus arguments.
        pids = subprocess.getoutput("pidof %s" % process).split()
        if not pids or not all(pid.isdigit() for pid in pids):
            raise Exception("%s process not found" % process)

        lines = subprocess.getoutput("pmap -d %s" % " ".join(pids)).splitlines()
        if not lines:
            raise Exception("%s process not found" % process)

        # mapped: 0K    writeable/private: 0K    shared: 0K
        values = lines[-1].split()
        # Use the mapped memory counter, dropping its trailing unit letter.
        if len(values) < 2 or not values[1][:-1].isdigit():
            raise Exception("unexpected pmap output for %s: %r"
                            % (process, lines[-1]))
        usage_sum += int(values[1][:-1])

        if n < samples:
            time.sleep(3)

    return usage_sum / samples


def get_gem_objects():
    """Return the number of i915 GEM objects currently allocated.

    The DRI screen number is taken from ``xdriinfo`` output (e.g.
    ``Screen 0: i915``). It is used to build the path
    ``/sys/kernel/debug/dri/<screen>/i915_gem_objects``, whose first line
    looks like ``2432 objects, 173256704 bytes``.

    Prints a message to stderr and exits with status 1 in any of these cases:
    xdriinfo output is not recognised, the debugfs file cannot be read, or
    its first line does not start with an object count.
    """
    dri = subprocess.getoutput("xdriinfo")
    parts = dri.split()
    if not dri.startswith('Screen') or len(parts) < 2:
        print("xdriinfo fail : %s" % (dri), file=sys.stderr)
        sys.exit(1)
    # "Screen 0: i915" -> "0" (strip the trailing colon)
    screen = parts[1][:-1]
    path = os.path.join("/sys/kernel/debug/dri/", screen, "i915_gem_objects")

    # 2432 objects, 173256704 bytes
    try:
        # Context manager closes the debugfs file once the line is read.
        with open(path, 'r') as gem_file:
            first_line = gem_file.readline()
    except IOError as err:
        # Covers a missing file as well as permission problems (debugfs
        # usually requires root), so report the actual error.
        print("Unable to read %s: %s" % (path, err), file=sys.stderr)
        sys.exit(1)

    fields = first_line.split()
    if not fields or not fields[0].isdigit():
        print("Unexpected contents in %s: %r" % (path, first_line),
              file=sys.stderr)
        sys.exit(1)
    return int(fields[0])


def compare_mem_results(orig, new, margin=0.1):
    """Return True if *new* lies within ``orig +/- orig * margin``.

    Used to compare pre-test and post-test X server memory usage; the
    default margin allows a 10% deviation either way. The bounds are
    inclusive, so an unchanged reading is never reported as a leak. This
    includes ``0 -> 0`` and any equal pair compared with ``margin=0``.
    """
    tolerance = orig * margin
    return orig - tolerance <= new <= orig + tolerance


def compare_gem_results(orig, new, margin=0.1):
    """Return True if the GEM object count *new* is within *margin* of *orig*.

    The accepted range is ``[orig * (1 - margin), orig * (1 + margin)]``.
    It is inclusive so that an unchanged count is never reported as a leak.
    This includes zero objects before and after, and any equal pair compared
    with ``margin=0``.
    """
    return orig * (1 - margin) <= new <= orig * (1 + margin)


def start_processes(name, number=1):
    """Launch *number* copies of the command *name* and return their PIDs.

    *name* is passed unchanged to ``subprocess.Popen`` (a program path or an
    argv list). Only the PIDs are kept, in launch order; the Popen objects
    themselves are discarded.
    """
    return [subprocess.Popen(name).pid for _ in range(number)]


def stop_processes(pids, signal=SIGKILL):
    """Send *signal* (SIGKILL by default) to every PID in *pids*.

    A process that has already exited and been reaped has no PID left to
    signal, and os.kill raises ProcessLookupError for it. Such processes are
    skipped rather than aborting the loop, so the remaining processes still
    receive the signal.
    """
    for pid in pids:
        try:
            os.kill(pid, signal)
        except ProcessLookupError:
            # Already gone, e.g. the program exited on its own and was reaped.
            continue

def main():
    """Check whether repeatedly opening and closing an X client leaks memory.

    Steps:
      1. Sample the X server's mapped memory and the i915 GEM object count.
      2. Start ``--processes`` copies of the given program.
      3. Wait ``--sleep`` seconds, kill the copies, and wait ``--sleep``
         seconds again.
      4. Take the same readings and compare them with the first ones.

    Returns:
        0 when both readings stayed within ``--margin`` of the originals,
        otherwise a message string. ``sys.exit`` prints such a string to
        stderr and exits with status 1.
    """
    # Parse options
    parser = ArgumentParser()
    # nargs=1 makes args.program a one-element list, which Popen accepts
    # as an argv.
    parser.add_argument("program",
        nargs=1,
        help="Specify an X program to create memory load with.")
    parser.add_argument("-m", "--margin",
        default=DEFAULT_MARGIN,
        type=float,
        help="Margin of error for memory usage [Default: %(default)s]")
    parser.add_argument("-p", "--processes",
        default=DEFAULT_PROCESSES,
        type=int,
        help="Number of processes to start and stop [Default: %(default)s]")
    parser.add_argument("--samples",
        default=DEFAULT_SAMPLES,
        type=int,
        help="Number of samples to get the memory usage [Default: %(default)s]")
    parser.add_argument("--sleep",
        default=DEFAULT_SLEEP,
        type=int,
        help=("Seconds to sleep between starting and stopping processes [Default: %(default)s]"))
    args = parser.parse_args()

    # Reject option values the rest of the script cannot use:
    # - mem_usage averages over --samples readings, so at least one is needed;
    # - time.sleep() rejects negative durations;
    # - a negative process count or margin makes the comparisons meaningless.
    if args.samples < 1:
        parser.error("--samples must be at least 1")
    if args.processes < 0:
        parser.error("--processes must not be negative")
    if args.sleep < 0:
        parser.error("--sleep must not be negative")
    if args.margin < 0:
        parser.error("--margin must not be negative")

    # Check current video card driver memory usage
    begin_mem_usage = mem_usage("X", args.samples)

    # Check current GEM object usage
    begin_gem_obj = get_gem_objects()

    # Open windows and let the system come up to speed. The finally clause
    # kills the windows even if the wait is interrupted (e.g. by Ctrl-C),
    # so they are not left running.
    pids = start_processes(args.program, args.processes)
    try:
        time.sleep(args.sleep)
    finally:
        # Close windows
        stop_processes(pids)
    time.sleep(args.sleep)

    # Check video card driver's memory usage again to see if it returned to
    # the value found previously (the sleep above gives the system time to
    # normalize)
    end_mem_usage = mem_usage("X", args.samples)

    # Check GEM object usage again, for state after running processes
    end_gem_obj = get_gem_objects()

    # Compare new values to old values
    if not compare_mem_results(begin_mem_usage, end_mem_usage, args.margin):
        return "Xorg memory leak detected, before %d and after %d" \
            % (begin_mem_usage, end_mem_usage)
    if not compare_gem_results(begin_gem_obj, end_gem_obj, args.margin):
        return "DRI GEM objects leak detected, before %d and after %d" \
            % (begin_gem_obj, end_gem_obj)

    return 0

if __name__ == "__main__":
    # main() returns 0 on success or a message string; sys.exit turns a
    # string into status 1 after printing it to stderr.
    exit_status = main()
    sys.exit(exit_status)
