#!/usr/bin/python

import curses
import sys, os, time, optparse

class DebugfsProvider(object):
    """Provide KVM statistics from the counter files below a debugfs directory.

    Each regular file found in ``self.base`` is one field; its content is
    parsed as a single integer every time read() is called.
    """

    def __init__(self):
        self.base = '/sys/kernel/debug/kvm'
        # Only regular files can be parsed as counters; anything else in the
        # directory would make read() fail.
        self._fields = [name
                        for name in os.listdir(self.base)
                        if os.path.isfile(os.path.join(self.base, name))]

    def fields(self):
        """Return the names of the currently selected fields."""
        return self._fields

    def select(self, fields):
        """Restrict subsequent read() calls to *fields*."""
        self._fields = fields

    def read(self):
        """Return a dict mapping every selected field to its current value."""
        def val(key):
            # Close the file deterministically instead of relying on the
            # garbage collector.
            with open(os.path.join(self.base, key)) as counter:
                return int(counter.read())
        return dict([(key, val(key)) for key in self._fields])

# Exit-reason number -> symbolic name.  This table is the one selected through
# vendor_exit_reasons when /proc/cpuinfo lists the 'vmx' flag.  Numbers that
# are absent here are simply not named by this script.
vmx_exit_reasons = {
    0: 'EXCEPTION_NMI',
    1: 'EXTERNAL_INTERRUPT',
    2: 'TRIPLE_FAULT',
    7: 'PENDING_INTERRUPT',
    8: 'NMI_WINDOW',
    9: 'TASK_SWITCH',
    10: 'CPUID',
    12: 'HLT',
    14: 'INVLPG',
    15: 'RDPMC',
    16: 'RDTSC',
    18: 'VMCALL',
    19: 'VMCLEAR',
    20: 'VMLAUNCH',
    21: 'VMPTRLD',
    22: 'VMPTRST',
    23: 'VMREAD',
    24: 'VMRESUME',
    25: 'VMWRITE',
    26: 'VMOFF',
    27: 'VMON',
    28: 'CR_ACCESS',
    29: 'DR_ACCESS',
    30: 'IO_INSTRUCTION',
    31: 'MSR_READ',
    32: 'MSR_WRITE',
    33: 'INVALID_STATE',
    36: 'MWAIT_INSTRUCTION',
    39: 'MONITOR_INSTRUCTION',
    40: 'PAUSE_INSTRUCTION',
    41: 'MCE_DURING_VMENTRY',
    43: 'TPR_BELOW_THRESHOLD',
    44: 'APIC_ACCESS',
    48: 'EPT_VIOLATION',
    49: 'EPT_MISCONFIG',
    54: 'WBINVD',
    55: 'XSETBV',
}

# Exit-reason number (written in hex) -> symbolic name.  This table is the one
# selected through vendor_exit_reasons when /proc/cpuinfo lists the 'svm'
# flag.  Numbers that are absent here are simply not named by this script.
svm_exit_reasons = {
    0x000: 'READ_CR0',
    0x003: 'READ_CR3',
    0x004: 'READ_CR4',
    0x008: 'READ_CR8',
    0x010: 'WRITE_CR0',
    0x013: 'WRITE_CR3',
    0x014: 'WRITE_CR4',
    0x018: 'WRITE_CR8',
    0x020: 'READ_DR0',
    0x021: 'READ_DR1',
    0x022: 'READ_DR2',
    0x023: 'READ_DR3',
    0x024: 'READ_DR4',
    0x025: 'READ_DR5',
    0x026: 'READ_DR6',
    0x027: 'READ_DR7',
    0x030: 'WRITE_DR0',
    0x031: 'WRITE_DR1',
    0x032: 'WRITE_DR2',
    0x033: 'WRITE_DR3',
    0x034: 'WRITE_DR4',
    0x035: 'WRITE_DR5',
    0x036: 'WRITE_DR6',
    0x037: 'WRITE_DR7',
    0x040: 'EXCP_BASE',
    0x060: 'INTR',
    0x061: 'NMI',
    0x062: 'SMI',
    0x063: 'INIT',
    0x064: 'VINTR',
    0x065: 'CR0_SEL_WRITE',
    0x066: 'IDTR_READ',
    0x067: 'GDTR_READ',
    0x068: 'LDTR_READ',
    0x069: 'TR_READ',
    0x06a: 'IDTR_WRITE',
    0x06b: 'GDTR_WRITE',
    0x06c: 'LDTR_WRITE',
    0x06d: 'TR_WRITE',
    0x06e: 'RDTSC',
    0x06f: 'RDPMC',
    0x070: 'PUSHF',
    0x071: 'POPF',
    0x072: 'CPUID',
    0x073: 'RSM',
    0x074: 'IRET',
    0x075: 'SWINT',
    0x076: 'INVD',
    0x077: 'PAUSE',
    0x078: 'HLT',
    0x079: 'INVLPG',
    0x07a: 'INVLPGA',
    0x07b: 'IOIO',
    0x07c: 'MSR',
    0x07d: 'TASK_SWITCH',
    0x07e: 'FERR_FREEZE',
    0x07f: 'SHUTDOWN',
    0x080: 'VMRUN',
    0x081: 'VMMCALL',
    0x082: 'VMLOAD',
    0x083: 'VMSAVE',
    0x084: 'STGI',
    0x085: 'CLGI',
    0x086: 'SKINIT',
    0x087: 'RDTSCP',
    0x088: 'ICEBP',
    0x089: 'WBINVD',
    0x08a: 'MONITOR',
    0x08b: 'MWAIT',
    0x08c: 'MWAIT_COND',
    0x400: 'NPF',
}

# CPU feature flag (as spelled in /proc/cpuinfo) -> exit-reason table to use.
vendor_exit_reasons = {
    'vmx': vmx_exit_reasons,
    'svm': svm_exit_reasons,
}

# Start from an empty table rather than None: the filter tables built from
# this value are inverted further down, which would fail on None when the
# CPU flags list neither 'vmx' nor 'svm'.
exit_reasons = {}

# Pick the table matching whichever known flag the 'flags' lines advertise.
with open('/proc/cpuinfo') as cpuinfo:
    for line in cpuinfo:
        if line.startswith('flags'):
            for flag in line.split():
                if flag in vendor_exit_reasons:
                    exit_reasons = vendor_exit_reasons[flag]

# tracepoint name -> (name of the tracepoint field to filter on,
#                     table of the values that field can take)
filters = {
    'kvm_exit': ('exit_reason', exit_reasons)
}

def invert(d):
    """Return a new dict that maps every value of *d* back to its key.

    An empty or missing (None) mapping yields an empty dict, so callers do
    not have to special-case a table that could not be determined.  If
    several keys share a value, which key survives is unspecified.
    """
    if not d:
        return {}
    return dict((value, key) for key, value in d.items())

# Replace each filter's number -> name table by its name -> number inverse,
# keeping the field name in the first slot of the tuple.
for tracepoint in list(filters):
    subfield, table = filters[tracepoint]
    filters[tracepoint] = (subfield, invert(table))

import ctypes, struct, array

# Load the C library by its soname and keep a handle on its generic
# syscall() entry point; _perf_event_open() issues its system call through it.
libc = ctypes.CDLL('libc.so.6')
syscall = libc.syscall
class perf_event_attr(ctypes.Structure):
    """ctypes structure handed by pointer to _perf_event_open().

    Two 32-bit fields, five 64-bit fields, two 32-bit fields and two 64-bit
    fields, in that order.
    NOTE(review): presumably mirrors the leading members of the kernel's
    struct perf_event_attr -- confirm against the kernel headers.
    """
    _fields_ = [
        ('type', ctypes.c_uint32),
        ('size', ctypes.c_uint32),
        ('config', ctypes.c_uint64),
        ('sample_freq', ctypes.c_uint64),
        ('sample_type', ctypes.c_uint64),
        ('read_format', ctypes.c_uint64),
        ('flags', ctypes.c_uint64),
        ('wakeup_events', ctypes.c_uint32),
        ('bp_type', ctypes.c_uint32),
        ('bp_addr', ctypes.c_uint64),
        ('bp_len', ctypes.c_uint64),
    ]
def _perf_event_open(attr, pid, cpu, group_fd, flags):
    """Issue system call number 298 through libc's syscall() and return its result.

    attr is passed by pointer; pid, cpu and group_fd are passed as C ints
    and flags as a C long.
    NOTE(review): 298 is presumably __NR_perf_event_open on x86_64 only --
    confirm before running on another architecture.
    """
    arguments = (
        ctypes.pointer(attr),
        ctypes.c_int(pid),
        ctypes.c_int(cpu),
        ctypes.c_int(group_fd),
        ctypes.c_long(flags),
    )
    return syscall(298, *arguments)

# Values for perf_event_attr.type.
PERF_TYPE_HARDWARE = 0
PERF_TYPE_SOFTWARE = 1
PERF_TYPE_TRACEPOINT = 2
PERF_TYPE_HW_CACHE = 3
PERF_TYPE_RAW = 4
PERF_TYPE_BREAKPOINT = 5

# Bit flags OR-ed together into perf_event_attr.sample_type.
PERF_SAMPLE_IP = 0x001
PERF_SAMPLE_TID = 0x002
PERF_SAMPLE_TIME = 0x004
PERF_SAMPLE_ADDR = 0x008
PERF_SAMPLE_READ = 0x010
PERF_SAMPLE_CALLCHAIN = 0x020
PERF_SAMPLE_ID = 0x040
PERF_SAMPLE_CPU = 0x080
PERF_SAMPLE_PERIOD = 0x100
PERF_SAMPLE_STREAM_ID = 0x200
PERF_SAMPLE_RAW = 0x400

# Bit flags for perf_event_attr.read_format.
PERF_FORMAT_TOTAL_TIME_ENABLED = 0x1
PERF_FORMAT_TOTAL_TIME_RUNNING = 0x2
PERF_FORMAT_ID = 0x4
PERF_FORMAT_GROUP = 0x8

import re

class TracepointProvider(object):
    """Provide KVM statistics by counting kvm tracepoints with perf events.

    For every CPU listed under /sys/devices/system/cpu one perf event is
    opened per selected field.  The events of a CPU form one group, and
    reading the group leader yields the counts of the whole group.  A field
    spelled ``event(NAME)`` counts only the occurrences of ``event`` that
    match the NAME entry of the module-level ``filters`` table.
    """

    def __init__(self):
        self.base = '/sys/kernel/debug/tracing/events/kvm/'
        fields = [f
                  for f in os.listdir(self.base)
                  if os.path.isdir(self.base + '/' + f)]
        extra = []
        for f in fields:
            if f in filters:
                # One additional pseudo-field per value the filter knows.
                for name in filters[f][1]:
                    extra.append(f + '(' + name + ')')
        fields += extra
        self.select(fields)

    def fields(self):
        """Return the names of the currently selected fields."""
        return self._fields

    def _close_events(self):
        """Close every perf event left open by an earlier select() call."""
        # Group leaders are owned by the file objects, the remaining events
        # are plain descriptors.
        for counter_file in getattr(self, 'files', []):
            counter_file.close()
        for fd in getattr(self, '_member_fds', []):
            os.close(fd)
        self.files = []
        self._member_fds = []
        self.fds = []
        self.group_leaders = []

    def select(self, _fields):
        """Open one perf event per field and CPU, replacing any earlier set.

        Raises Exception('perf_event_open failed') when the system call
        reports failure; every event opened by this call is closed again
        before the exception propagates.
        """
        import fcntl
        import resource
        # select() is called again after construction (by Stats); without
        # this the events of the previous selection would stay open.
        self._close_events()
        self._fields = _fields
        cpure = re.compile(r'cpu([0-9]+)')
        self.cpus = []
        for entry in os.listdir('/sys/devices/system/cpu'):
            m = cpure.match(entry)
            if m:
                self.cpus.append(int(m.group(1)))
        nfiles = len(self.cpus) * 1000
        resource.setrlimit(resource.RLIMIT_NOFILE, (nfiles, nfiles))

        # Resolve every field to (tracepoint id, filter expression) once
        # instead of re-reading the id file for each CPU.
        events = []
        for f in _fields:
            fbase, sub = f, None
            m = re.match(r'(.*)\((.*)\)', f)
            if m:
                fbase, sub = m.groups()
            with open(self.base + fbase + '/id') as id_file:
                tracepoint_id = int(id_file.read())
            filter_expr = None
            if sub:
                filter_expr = '%s==%d\0' % (filters[fbase][0],
                                            filters[fbase][1][sub])
            events.append((tracepoint_id, filter_expr))

        opened = []
        leaders = []
        complete = False
        try:
            for cpu in self.cpus:
                group_leader = -1
                for tracepoint_id, filter_expr in events:
                    attr = perf_event_attr()
                    attr.type = PERF_TYPE_TRACEPOINT
                    attr.size = ctypes.sizeof(attr)
                    attr.config = tracepoint_id
                    attr.sample_type = (PERF_SAMPLE_RAW
                                        | PERF_SAMPLE_TIME
                                        | PERF_SAMPLE_CPU)
                    # NOTE(review): perf_event_attr as declared in this file
                    # has a 'sample_freq' field but no 'sample_period', so
                    # this only sets a Python attribute and does not reach
                    # the structure passed to the kernel.  Kept unchanged.
                    attr.sample_period = 1
                    attr.read_format = PERF_FORMAT_GROUP
                    fd = _perf_event_open(attr, -1, cpu, group_leader, 0)
                    if fd == -1:
                        raise Exception('perf_event_open failed')
                    opened.append(fd)
                    if filter_expr is not None:
                        # NOTE(review): 0x40082406 is presumably
                        # PERF_EVENT_IOC_SET_FILTER on 64-bit -- confirm.
                        fcntl.ioctl(fd, 0x40082406, filter_expr)
                    if group_leader == -1:
                        group_leader = fd
                # With no fields selected there is no leader to read from.
                if group_leader != -1:
                    leaders.append(group_leader)
            complete = True
        finally:
            if not complete:
                for fd in opened:
                    os.close(fd)

        leader_set = set(leaders)
        self.group_leaders = leaders
        self.fds = list(leaders)
        self._member_fds = [fd for fd in opened if fd not in leader_set]
        self.files = [os.fdopen(group_leader, 'rb')
                      for group_leader in leaders]

    def read(self):
        """Return a dict mapping every selected field to its count summed over all CPUs."""
        ret = dict([(f, 0) for f in self._fields])
        # Each group read is 8 skipped bytes followed by one signed 64-bit
        # count per field (presumably the PERF_FORMAT_GROUP layout, whose
        # first word is the number of events -- TODO confirm).
        nbytes = 8 * (1 + len(self._fields))
        fmt = 'xxxxxxxx' + 'q' * len(self._fields)
        for counter_file in self.files:
            counts = struct.unpack(fmt, counter_file.read(nbytes))
            for field, val in zip(self._fields, counts):
                ret[field] += val
        return ret

class Stats:
    """Track provider counters and the change of each one between two get() calls.

    provider -- object offering fields(), select(names) and read()
    fields   -- optional regular expression; only provider fields it matches
                (anchored at the start, as re.match does) are kept
    """
    def __init__(self, provider, fields = None):
        self.provider = provider
        available = provider.fields()
        if fields:
            chosen = [key for key in available
                      if re.match(fields, key) != None]
        else:
            chosen = list(available)
        # No sample has been taken yet, so every field starts out as None.
        self.values = dict.fromkeys(chosen)
        self.provider.select(self.values.keys())
    def get(self):
        """Sample the provider and return {field: (value, delta)}.

        delta is None for a field that has not been sampled before.
        """
        sample = self.provider.read()
        for key in self.provider.fields():
            previous = self.values[key]
            current = sample[key]
            if previous is None:
                change = None
            else:
                change = current - previous[0]
            self.values[key] = (current, change)
        return self.values

# Refuse to start when the debugfs tree or its kvm directory is missing, and
# tell the user what to do about it.
if not os.access('/sys/kernel/debug', os.F_OK):
    sys.stdout.write('Please enable CONFIG_DEBUG_FS in your kernel\n')
    sys.exit(1)
if not os.access('/sys/kernel/debug/kvm', os.F_OK):
    sys.stdout.write(
        "Please mount debugfs ('mount -t debugfs debugfs /sys/kernel/debug')\n"
        "and ensure the kvm modules are loaded\n")
    sys.exit(1)

# Column widths of the curses table: field name first, then the total.
label_width, number_width = 40, 10

def tui(screen, stats):
    """Run the interactive curses display until 'q' is pressed or Ctrl-C arrives.

    screen -- curses window (this function is started through curses.wrapper)
    stats  -- object whose get() returns {name: (total, delta)}

    Rows are sorted by delta and then by total, both descending.  The third
    column shows each delta divided by the time that really elapsed since
    the previous sample.
    """
    curses.use_default_colors()
    curses.noecho()
    # Time of the previous sample.  Kept in a list because a nested function
    # cannot rebind a local of its enclosing function in Python 2.
    last_sample = [None]
    def refresh(sleeptime):
        screen.erase()
        screen.addstr(0, 0, 'kvm statistics')
        row = 2
        s = stats.get()
        # Divide by the measured interval: the first interval is shorter
        # than the nominal sleeptime handed in, and any key press cuts an
        # interval short as well.
        now = time.time()
        elapsed = sleeptime
        if last_sample[0] is not None and now > last_sample[0]:
            elapsed = now - last_sample[0]
        last_sample[0] = now
        def sortkey(x):
            if s[x][1]:
                return (-s[x][1], -s[x][0])
            else:
                return (0, -s[x][0])
        for key in sorted(s.keys(), key = sortkey):
            if row >= screen.getmaxyx()[0]:
                break
            values = s[key]
            # The list is sorted, so once a row is all zero the rest is too.
            if not values[0] and not values[1]:
                break
            col = 1
            screen.addstr(row, col, key)
            col += label_width
            screen.addstr(row, col, '%10d' % (values[0],))
            col += number_width
            if values[1] is not None:
                screen.addstr(row, col, '%8d' % (values[1] / elapsed,))
            row += 1
        screen.refresh()

    # Short first interval so rates show up quickly, then every 3 seconds.
    sleeptime = 0.25
    while True:
        refresh(sleeptime)
        # halfdelay takes tenths of a second; getkey then raises
        # curses.error when no key arrived in time.
        curses.halfdelay(int(sleeptime * 10))
        sleeptime = 3
        try:
            c = screen.getkey()
            if c == 'q':
                break
        except KeyboardInterrupt:
            break
        except curses.error:
            continue

def batch(stats):
    """Sample twice, one second apart, and print total and delta of every field.

    One line per field, sorted by field name.
    """
    stats.get()
    time.sleep(1)
    snapshot = stats.get()
    for key in sorted(snapshot.keys()):
        total, delta = snapshot[key]
        print('%-22s%10d%10d' % (key, total, delta))

def log(stats):
    """Print one line of per-second deltas forever, vmstat style.

    A header with the field names (cut to nine characters) is repeated
    every 20 lines.  The function never returns.
    """
    keys = sorted(stats.get())
    banner_repeat = 20

    def emit(cells):
        # Cells are separated by a single space, the line ends with a newline.
        sys.stdout.write(' '.join(cells) + '\n')

    line = 0
    while True:
        time.sleep(1)
        if line % banner_repeat == 0:
            emit(['%10s' % k[0:9] for k in keys])
        s = stats.get()
        emit([' %9d' % s[k][1] for k in keys])
        line += 1

# Command line: -1/--once/--batch, -l/--log and -f/--fields REGEX.
parser = optparse.OptionParser()
parser.add_option('-1', '--once', '--batch', dest='once',
                  action='store_true', default=False,
                  help='run in batch mode for one second')
parser.add_option('-l', '--log', dest='log',
                  action='store_true', default=False,
                  help='run in logging mode (like vmstat)')
parser.add_option('-f', '--fields', dest='fields',
                  action='store', default=None,
                  help='fields to display (regex)')
options, args = parser.parse_args(sys.argv)

# Prefer the tracepoint based counters and fall back to the debugfs files
# when they cannot be set up.  Only ordinary errors trigger the fallback;
# Ctrl-C and sys.exit() are no longer swallowed.
try:
    provider = TracepointProvider()
except Exception:
    provider = DebugfsProvider()

stats = Stats(provider, fields = options.fields)

# Pick the output mode: vmstat-like log, interactive curses display
# (the default), or a one-shot batch report.
if options.log:
    log(stats)
elif not options.once:
    import curses.wrapper
    curses.wrapper(tui, stats)
else:
    batch(stats)
