#!/usr/bin/python

# $Id: diskdumpmsg,v 1.8 2006/04/28 23:23:13 tachino Exp $

from os import kill, remove, popen, close, write, open, O_CREAT, O_TRUNC, \
               O_RDWR, environ, sysconf
from os.path import getmtime, exists
from struct import calcsize, unpack
from string import join
from re import compile, sub
import re
from sys import exit, stderr, stdout, version_info
from time import strftime, strptime, localtime
from tempfile import mktemp, mkstemp
from optparse import OptionParser
from zlib import decompress

# Run in the C locale.  Assignments to os.environ are also inherited by the
# commands this script spawns via popen (/bin/arch, objdump, diff), so their
# output is not localized.
environ['LANG'] = 'C'

# Program name used in the usage message.
cmdname = 'diskdumpmsg'


# -----------------------------------------------------------------------
#
# Helper functions
#

def err(*msgs):
    """
    Write an error line to stderr.

    Every argument is converted with str(); the results are separated by a
    single space and terminated by a newline.
    """

    line = ' '.join([str(m) for m in msgs])
    stderr.write(line + '\n')

def verr(*msgs):
    """
    Same as err(), but only emits anything when the --verbose option
    (options.verbose) is set.
    """

    if not options.verbose:
        return
    err(*msgs)

def derr(*msgs):
    """
    Same as err(), but only emits anything when the --debug option
    (options.debug) is set.
    """

    if not options.debug:
        return
    err(*msgs)

#
# Temporary-file helper: mymktemp(suffix) returns (fd, filename) like
# tempfile.mkstemp.  If mkstemp is not supported, use own function.
#
# NOTE(review): `mkstemp` is imported by name at the top of the file, so on a
# Python without it that import already fails with ImportError and the
# AttributeError branch below can never run.  The fallback is also racy
# (mktemp() followed by an open without O_EXCL); `open` here is os.open, so
# the file gets os.open's default mode.
#
try:
    mymktemp = mkstemp
except AttributeError:
    def mymktemp(suffix):
        filename = mktemp(suffix)
        fd = open(filename, O_RDWR|O_TRUNC|O_CREAT)
        return (fd, filename)


def is_alive(pidfile):
    """
    Check if the daemon is alive. This is the same logic used by syslogd and
    klogd: read a pid from pidfile and probe it with signal 0.

    Returns False when the pid file is missing, unreadable or does not hold
    a positive integer, or when no process with that pid exists.  A process
    that exists but may not be signalled by us (EPERM) counts as alive.
    """

    import errno
    import sys

    try:
        fd = file(pidfile)
        try:
            pid = int(fd.read())
        finally:
            fd.close()
    except (IOError, OSError, ValueError):
        return False

    # kill() with pid 0 or a negative pid addresses a process group, which
    # could succeed without the daemon being alive.
    if pid <= 0:
        return False

    try:
        kill(pid, 0)
    except OverflowError:
        # pid does not fit in a pid_t: cannot be a real process.
        return False
    except OSError:
        # EPERM means the process exists but belongs to someone else.
        if sys.exc_info()[1].errno != errno.EPERM:
            return False

    verr('The process is running:', pidfile)
    return True

# -----------------------------------------------------------------------

class SystemMap:
    """
    System.map class.

    Loads a System.map file, one "<hex address> <type> <symbol>" entry per
    line, into self.symbols: symbol -> (address, type).  Lines with fewer
    than three fields are ignored; for duplicate symbols the last one wins.
    """

    def __init__(self, mapfile):
        self.symbols = {}

        fd = file(mapfile)
        try:
            for line in fd:
                items = line.split()
                if len(items) < 3:
                    continue

                addr, symtype, symbol = items[0:3]
                self.symbols[symbol] = (long(addr, 16), symtype)
        finally:
            fd.close()

    def to_val(self, sym):
        """
        convert the symbol name to its value (address).

        Raises KeyError if the symbol is not in the map.
        """
        return self.symbols[sym][0]


# -----------------------------------------------------------------------

class Vmcore:
    """
    Common vmcore class.
    """

    CMD_ARCH = "/bin/arch"

    bin32on64 = False

    def __init__(self, vmcore, map=None):
        if self == Vmcore:
            raise 'Cannot instanciate from base Vmcore class'

        self.blocksize = sysconf("SC_PAGE_SIZE")
        self.file = vmcore
        self.fd = self.datafd = file(vmcore)
        self.version = False
        self.arch = popen(self.CMD_ARCH).readline().strip()

        # If 32bit python on 64bit kernel, set bin32on64.
        self.bin32on64 = (self.arch == 'ppc64')

    def set_map(self, map):
        """
        If System.map file is not specified, guess by the kernel version
        in vmcore file.
        """
            
        if map == None:
            self.map = SystemMap('/boot/System.map-' + self.get_version())
        else:
            self.map = map

    def v_to_p(self, addr):
        """
        ad hoc virtual to physical address conversion.
        """

        if isinstance(addr, str):
            addr = self.map.to_val(addr)

        # Architecuture dependent virt to phys conversion.
        if self.arch == 'ia64':
            if 0xa000000000000000L <= addr and addr < 0xc000000000000000L:
                paddr = (addr & 0x0fffffffffffffffL) - 0xfc000000L
            elif 0xc000000000000000L <= addr:
                paddr = addr & 0x0fffffffffffffffL
        elif self.arch == 'x86_64':
            paddr = addr & 0x000000007fffffffL
        elif self.arch == 'i386' or self.arch == 'i686':
            paddr = addr - 0xc0000000L
        elif self.arch == 'ppc64':
            paddr = addr & 0x0fffffffffffffffL

        return paddr

    def v_to_off(self, addr):
        """
        convert the virtual address or symbol to file offset.
        """

        return self.p_to_off(self.v_to_p(addr))

    def readlong(self, addr):
        """
        read long value at the specified virtual address or symbol from vmcore.
        """

        if self.bin32on64:
            packstr = "Q"
        else:
            packstr = "L"
        calcsizeL = calcsize(packstr)

        offset = self.v_to_off(addr)

        derr('read %s => %lx' % (str(addr), offset))

        self.datafd.seek(offset)
        data = self.datafd.read(calcsizeL)
        if len(data) != calcsizeL:
            raise 'read error: %d expected %d ' % (len(data), calcsizeL)

        return unpack(packstr, data)[0]

    def readint(self, addr):
        """
        read int value at the specified virtual address or symbol from vmcore.
        """

        packstr = "i"
        calcsizei = calcsize(packstr)

        offset = self.v_to_off(addr)

        derr('read %s => %lx' % (str(addr), offset))

        self.datafd.seek(offset)
        data = self.datafd.read(calcsizei)
        if len(data) != calcsizei:
            raise 'read error: %d expected %d ' % (len(data), calcsizei)

        return unpack(packstr, data)[0]

    def slice(self, addr, length):
        """
        read memory area which start with the specified address.
        """

        self.datafd.seek(addr)
        data = self.datafd.read(length)
        if len(data) != length:
            raise 'read error: %d expected %d ' % (len(data), length)

        return data


    def timeofdeath(self):
        """
        get time of death of this vmcore.
        """

        return self.readlong('xtime')

    def get_logbuf(self):
        """
        get logbuf area.
        """

        log_buf = self.readlong('log_buf')
        log_buf_offset = self.v_to_off(log_buf)
        log_buf_len = self.readint('log_buf_len')

        # sanity check
        if log_buf_len > 1048576 * 16:
            raise "log_buf_len(%d) is too large" % log_buf_len
            

        # start and end marker
        log_start = self.readlong('log_start')
        log_end = self.readlong('log_end')

        derr('log_buf_offset %lx' % log_buf)
        derr('log_buf_len %ld' % log_buf_len)
        derr('log_start %ld' % log_start)
        derr('log_end %ld' % log_end)

        if log_end < log_buf_len:
            # log_bug is not filled by messages.

            log =  self.slice(log_buf_offset, log_end)
        else:
            # log_bug is filled by messages.
            # log_end points the oldest message in log_buf.

            log_end = log_end % log_buf_len
            log = self.slice(log_buf_offset + log_end, log_buf_len - log_end)+\
                  self.slice(log_buf_offset, log_end)

        # delete priority labels
        return sub(compile(r'^<\d\>', re.M), '', log)

    def get_hostname(self):
        """
        Get nodename from vmcore. gethostname() cannot be used because
        network may be not activated when this script is executed.
        """

        utsname_offset = self.v_to_off('system_utsname')
        utsname = self.slice(utsname_offset, 130)

        # struct new_utsname {
        #       char sysname[65];
        #       char nodename[65];
        #       ..

        nodename = utsname[65:]

        pos = nodename.find('.')
        if pos >= 0:
            nodename = nodename[0:pos]
        nodename = nodename.strip('\x00')

        return nodename

    def generate(cls, vmcore, map=None):
        """
        Inspect the vmcore and generate the instance of apropriate subclass
        """
        for subclass in [ElfVmcore, CompressedVmcore]:
            if subclass.match(vmcore):
                return subclass(vmcore, map)

    generate = classmethod(generate)

# -----------------------------------------------------------------------
#
# Elf vmcore

class ElfProgHeader:
    """
    makeshift ELF Program Header class.

    Common base of ElfProgHeader32 and ElfProgHeader64, which unpack the
    fields type, offset, vaddr, paddr, filesz, memsz, flags and align.
    """

    # p_type value of a loadable segment.
    PT_LOAD = 1

    def __init__(self, line):
        # All parsing happens in the subclasses; nothing to do here.
        return None

class ElfProgHeader32(ElfProgHeader):
    """
    32bit program header: eight 32bit fields in the order type, offset,
    vaddr, paddr, filesz, memsz, flags, align (host byte order).
    """
    packstr = 'IIIIIIII'

    def __init__(self, line):
        values = unpack(self.packstr, line)
        self.type, self.offset, self.vaddr, self.paddr = values[0:4]
        self.filesz, self.memsz, self.flags, self.align = values[4:8]

        ElfProgHeader.__init__(self, line)

class ElfProgHeader64(ElfProgHeader):
    """
    64bit program header (host byte order).  Unlike the 32bit layout,
    flags directly follows type; the remaining six fields are 64bit.
    """
    packstr = 'IIQQQQQQ'

    def __init__(self, line):
        values = unpack(self.packstr, line)
        self.type, self.flags = values[0:2]
        (self.offset, self.vaddr, self.paddr,
         self.filesz, self.memsz, self.align) = values[2:8]

        ElfProgHeader.__init__(self, line)


class Elfheader:
    """
    makeshift ELF Header class.

    Parses the ELF header of filename (host byte order) and collects its
    PT_LOAD program headers in self.prog_headers.

    Raises ValueError if the file is neither 32bit nor 64bit ELF, and
    IOError if the program header table cannot be read completely.
    """
    def __init__(self, filename):
        fd = file(filename)
        try:
            header = fd.read(sysconf("SC_PAGE_SIZE"))

            # The byte after "\x7fELF" selects 32bit (1) or 64bit (2).
            ident = header[0:5]
            if ident == '\x7fELF\x01':
                derr('32bit elf')
                self.bit64 = False
            elif ident == '\x7fELF\x02':
                derr('64bit elf')
                self.bit64 = True
            else:
                raise ValueError('Unsupported format')

            if self.bit64:
                packstr = '16sHHIQQQIHHHHHH'
                progheaderclass = ElfProgHeader64
            else:
                packstr = '16sHHIIIIIHHHHHH'
                progheaderclass = ElfProgHeader32

            (self.ident, self.type, self.machine, self.version, self.entry,
             self.phoff, self.shoff, self.flags, self.ehsize, self.phentsize,
             self.phnum, self.shentsize, self.shnum, self.shstrndx) = \
                unpack(packstr, header[0:calcsize(packstr)])

            # The program header table need not lie within the first page
            # read above, so read it from phoff explicitly.
            tablesize = self.phnum * self.phentsize
            fd.seek(self.phoff)
            table = fd.read(tablesize)
        finally:
            fd.close()

        if len(table) != tablesize:
            raise IOError('read error: %d expected %d ' %
                          (len(table), tablesize))

        self.prog_headers = []
        for i in range(self.phnum):
            start = i * self.phentsize
            ph = progheaderclass(table[start:start + self.phentsize])
            if ph.type == ph.PT_LOAD:
                self.prog_headers.append(ph)

class Ptload:
    """
    PT_LOAD Segment class.

    Built either from one line of `objdump -h` output for a "loadN" section,
    or from a program header object providing vaddr, memsz and offset.
    Raises ValueError for a line that is not a load section.
    """

    re_load = compile(r"load\d+\s")

    def __init__(self, arg, vmcore):
        if isinstance(arg, str):
            if not self.re_load.search(arg):
                raise ValueError('bad statement')

            items = arg.split()
            if len(items) < 6:
                raise ValueError('bad statement')

            # Columns used: index, name, size, VMA, LMA, file offset.
            (self.idx, self.name, size, vma, self.lma, offs) = items[:6]
            self.vma = long(vma, 16)
            self.size = long(size, 16)
            self.offs = long(offs, 16)
        else:
            self.idx = '0'
            self.name = 'loadl'
            self.lma = '0'
            self.vma = arg.vaddr
            self.size = arg.memsz
            self.offs = arg.offset

        self.vmaend = self.vma + self.size
        self.vmcore = vmcore

    def includes(self, paddr):
        """
        is the physical address specified included in this PT_LOAD segment?
        """

        return self.vmcore.v_to_p(self.vma) <= paddr and \
               paddr < self.vmcore.v_to_p(self.vmaend)

    def to_offset(self, paddr):
        """
        convert the physical address to a file offset according to this
        PT_LOAD segment.
        """

        return self.offs + (paddr - self.vmcore.v_to_p(self.vma))

class ElfVmcore(Vmcore):
    """
    Elf vmcore class.

    File offsets are resolved through the vmcore's PT_LOAD segments, taken
    from `objdump -h` output or, when that yields none, from the ELF
    program headers parsed directly.
    """

    CMD_OBJDUMP = "/usr/bin/objdump -h %s 2>/dev/null"

    def __init__(self, vmcore, map=None):
        Vmcore.__init__(self, vmcore, map)

        # read ELF PT_LOAD segments information

        self.loadsegs = []
        pipe = popen(self.CMD_OBJDUMP % vmcore)
        try:
            for seg in pipe:
                try:
                    self.loadsegs.append(Ptload(seg, self))
                except:
                    # Ptload rejects every objdump line that is not a
                    # "loadN" section (headers, flag lines, ...).  Whatever
                    # it raises just means "not a load segment".
                    pass
        finally:
            pipe.close()

        # on 32bit system, large vmcore cannot be analyzed by
        # objdump. Try to read Elf section by myself.
        #
        if len(self.loadsegs) == 0:
            elfheader = Elfheader(vmcore)

            for ph in elfheader.prog_headers:
                self.loadsegs.append(Ptload(ph, self))

        for seg in self.loadsegs:
            derr("load %lx(%lx)-%lx(%lx) %lx" % \
                 (seg.vma, self.v_to_p(seg.vma), seg.vmaend, \
                 self.v_to_p(seg.vmaend), seg.offs))

        if len(self.loadsegs) < 1:
            raise ValueError('vmcore has no PT_LOAD segment')

        self.set_map(map)


    def get_version(self):
        """
        get the kernel version by digging vmcore for the first
        "Linux version <release>" string.  The result is cached.

        Raises ValueError if no such string is found.
        """

        if self.version:
            return self.version

        self.datafd.seek(0)
        re_ver = compile(r"Linux version (\S+)")
        for l in self.datafd:
            match = re_ver.search(l)
            if match:
                self.version = match.group(1)
                return self.version

        raise ValueError('version not found')

    def p_to_off(self, paddr):
        """
        convert the physical address to file offset.

        Raises ValueError if no PT_LOAD segment contains paddr.
        """

        for load in self.loadsegs:
            derr('%lx-%lx %lx' % (load.vma, load.vmaend, paddr))

            if load.includes(paddr):
                return load.to_offset(paddr)

        raise ValueError('offset error: %x is not in any PT_LOAD segment'
                         % paddr)

    def match(cls, vmcore):
        """
        True if the file starts with the ELF magic.
        """
        fd = file(vmcore)
        try:
            return fd.read(4) == '\x7fELF'
        finally:
            fd.close()

    match = classmethod(match)

# -----------------------------------------------------------------------
#
# Compressed vmcore
#

def divideup(x, y):
    """
    Return x / y rounded up, for integers x and positive y.
    E.g. divideup(4097, 4096) == 2.
    """

    # '//' stays an integer floor division even when true division is in
    # effect (Python 3 or `from __future__ import division`).
    return (x + y - 1) // y

class CompressedVmcoreHeader:
    """
    Compressed vmcore header class.
    """

    def __init__(self, line, bin32on64=False):
        if bin32on64:
            packstr = '8si65s65s65s65s65s65s2sqqIiiIIIIIIi'
        else:
            packstr = '8si65s65s65s65s65s65s2sllIiiIIIIIIi'

        (self.signature, self.header_version, sysname, nodename, \
         release, version, machine, domainname, \
         self.dummy, self.sec, self.usec, self.status, self.block_size, \
         self.sub_hdr_size, self.bitmap_blocks, self.max_mapnr, \
         self.total_ram_blocks, self.device_blocks, self.written_blocks, \
         self.current_cpu, self.nr_cpus) = \
             unpack(packstr, line[0:calcsize(packstr)])

        self.sysname = sysname.strip('\x00')
        self.nodename = nodename.strip('\x00')
        self.release = release.strip('\x00')
        self.version = version.strip('\x00')
        self.machine = machine.strip('\x00')
        self.domainname = domainname.strip('\x00')

class PageDesc:
    """
    Page Descriptor class.

    One descriptor holds offset (file position of the page data), size
    (stored length), flags (COMPRESSED bit) and page_flags.  Note that the
    instance attribute `size` hides the size() classmethod on instances;
    use PageDesc.size() for the descriptor length.
    """

    packstr = "qIIQ"
    COMPRESSED = 0x1

    def size(cls):
        # Byte length of one packed descriptor.
        return calcsize(cls.packstr)

    size = classmethod(size)

    def __init__(self, line):
        nbytes = PageDesc.size()
        fields = unpack(self.packstr, line[:nbytes])
        self.offset, self.size, self.flags, self.page_flags = fields


class CompressedVmcore(Vmcore):
    """
    Compressed vmcore class.

    File layout as read here: one header block, sub_hdr_size sub header
    blocks, bitmap_blocks bitmap blocks, then an array of PageDesc entries
    pointing at the (optionally zlib-compressed) page data.  The pages of
    the kernel image (_stext.._end) are uncompressed into
    "<vmcore>-uncompressedrawdata", which then serves as self.datafd; that
    file is removed when the instance is destroyed.
    """

    # Sanity limit on the number of pages in _stext.._end.
    MAX_DUMP_BLOCKS = 100000

    def __init__(self, vmcore, map=None):
        Vmcore.__init__(self, vmcore, map)

        self.read_headers()
        self.set_map(map)

        phys_text_address = self.v_to_p('_stext')
        phys_end_address = self.v_to_p('_end')

        # Page frame range [block_start, block_end) of the kernel image.
        self.block_start = phys_text_address // self.blocksize
        self.block_end = divideup(phys_end_address, self.blocksize)

        # sanity check

        sz = self.block_end - self.block_start
        if sz < 0 or sz > self.MAX_DUMP_BLOCKS:
            raise ValueError('too large data %d' % sz)

        self.datafilename = vmcore + '-uncompressedrawdata'

        self.memory_dump(self.datafilename)

        self.datafd = file(self.datafilename)

    def __del__(self):
        """
        Remove the temporary uncompressed data file, if it was created.
        """
        # __init__ may have failed before the file was named or written.
        name = getattr(self, 'datafilename', None)
        if name is not None and exists(name):
            remove(name)

    def get_version(self):
        """
        Return the kernel release recorded in the dump header.
        """
        return self.header.release

    def read_blocks(self, nr):
        """
        Read nr blocks from the current position of the vmcore file.
        """
        return self.fd.read(self.blocksize * nr)

    def is_partial(self):
        """
        True when the bitmap area is big enough for two bitmaps: the RAM
        bitmap followed by the dumpable-page bitmap.
        """
        bitmap_bytes = divideup(self.header.max_mapnr, 8)
        return self.header.bitmap_blocks >= \
               divideup(bitmap_bytes, self.blocksize) * 2

    def read_headers(self):
        """
        Read header, sub_header, bitmap and dumpable_bitmap, and compute
        where the page descriptor array starts.

        Raises ValueError if the signature is not "DISKDUMP".
        """

        self.rawheader = self.read_blocks(1)
        self.header = CompressedVmcoreHeader(self.rawheader, self.bin32on64)

        if self.header.signature != 'DISKDUMP':
            raise ValueError("not a compressed vmcore")

        self.sub_header = self.read_blocks(self.header.sub_hdr_size)
        self.bitmap = self.read_blocks(self.header.bitmap_blocks)

        if self.is_partial():
            # The second half of the bitmap area is the dumpable bitmap.
            bitmap_len = self.header.bitmap_blocks * self.blocksize
            self.dumpable_bitmap = self.bitmap[bitmap_len // 2:bitmap_len]
        else:
            self.dumpable_bitmap = self.bitmap

        # page descriptor array starts here

        self.pd_offset = self.blocksize * \
            (1 + self.header.sub_hdr_size + self.header.bitmap_blocks)

    def page_is_ram(self, nr):
        """
        Non-zero if page frame nr is set in the RAM bitmap.
        """
        return unpack('B', self.bitmap[nr >> 3])[0] & (1 << (nr & 7))

    def page_is_dumpable(self, nr):
        """
        Non-zero if page frame nr is set in the dumpable bitmap.
        """
        return unpack('B', self.dumpable_bitmap[nr >> 3])[0] & (1 << (nr & 7))

    def _count_dumpable(self, start, end):
        """
        Return the number of dumpable page frames in [start, end).
        """
        count = 0
        for pfn in xrange(start, end):
            if self.page_is_dumpable(pfn):
                count = count + 1
        return count

    def memory_dump(self, filename):
        """
        Uncompress text/bss/data area and write it to the specified file.

        Page descriptors are indexed by the number of dumpable pages that
        precede a page frame, so only the descriptors of the dumpable pages
        in [block_start, block_end) are read.  Missing pages read back as
        zeros.
        """

        pd_size = PageDesc.size()

        first = self._count_dumpable(0, self.block_start)
        ndesc = self._count_dumpable(self.block_start, self.block_end)

        self.fd.seek(self.pd_offset + first * pd_size)
        page_desc_raw = self.fd.read(ndesc * pd_size)

        # read each page and write it to the temporary file. If compressed,
        # write uncompressed data.

        ofile = file(filename, "wb")
        try:
            idx = 0
            for pfn in xrange(self.block_start, self.block_end):

                # If the page does not exist, leave a hole (zeros).
                if not self.page_is_dumpable(pfn):
                    ofile.seek(self.blocksize, 1)
                    continue

                page_desc = PageDesc(
                    page_desc_raw[idx * pd_size:(idx + 1) * pd_size])
                self.fd.seek(page_desc.offset)
                page_data = self.fd.read(page_desc.size)

                if page_desc.flags & PageDesc.COMPRESSED:
                    page_data = decompress(page_data)

                ofile.write(page_data)
                idx = idx + 1

            # Seeking past the end does not grow the file; set its size to
            # the current position so trailing holes are materialized.
            ofile.truncate()
        finally:
            ofile.close()

    def p_to_off(self, paddr):
        """
        convert the physical address to an offset in the uncompressed file.
        """

        return paddr - self.block_start * self.blocksize

    def match(cls, vmcore):
        """
        True if the file starts with the "DISKDUMP" signature.
        """
        fd = file(vmcore)
        try:
            return fd.read(8) == 'DISKDUMP'
        finally:
            fd.close()

    match = classmethod(match)


# -----------------------------------------------------------------------
#
# /var/log/messages
#

class Message:
    """
    message line class.

    A line "Mon DD HH:MM:SS host ident body" is split into date (a time
    tuple parsed with the current year), host, ident and body, and
    badmessage becomes False.  For any other line badmessage stays True
    and date/host/ident stay None, body "".  wholemessage is the raw line.
    """

    def __init__(self, line):
        self.wholemessage = line
        self.date = None
        self.host = None
        self.ident = None
        self.body = ""
        self.badmessage = True

        fields = line.split(None, 5)
        if len(fields) < 5:
            # not a standard message format
            return

        # The syslog timestamp has no year; borrow the current one.
        stamp = ' '.join(fields[:3] + [strftime("%Y")])
        try:
            parsed = strptime(stamp, "%b %d %H:%M:%S %Y")
        except ValueError:
            return

        self.badmessage = False
        self.date = parsed
        self.host = fields[3]
        self.ident = fields[4]
        if len(fields) > 5:
            self.body = fields[5]


class Messages:
    """
    /var/log/messages class.

    Loads a syslog messages file as Message objects and appends kernel
    messages from a crash dump's log buffer that the file does not contain.
    """

    CMD_DIFF = "/usr/bin/diff -u %s %s"

    def __init__(self, messagefile, dry_run=False, force=False):
        """
        Read messagefile.  With force, a missing file is first created
        empty.  dry_run is accepted but not used here.
        """
        self.messagefile = messagefile
        self.messages = []

        if not exists(messagefile) and force:
            file(messagefile, 'w').close()

        fd = file(messagefile)
        try:
            for line in fd:
                self.messages.append(Message(line))
        finally:
            fd.close()


    def kernelmessages(self):
        """
        return the bodies of the well-formed "kernel:" messages, joined
        into one string.
        """

        kmes = []
        for mes in self.messages:
            if mes.badmessage:
                continue
            if mes.ident != 'kernel:':
                continue

            kmes.append(mes.body)

        return ''.join(kmes)


    def kernelmessages_after_reboot(self):
        """
        return kernel messages after the latest reboot, i.e. from the last
        line starting with "Linux version " (everything if none is found).
        """

        msgs = self.kernelmessages()

        pos = msgs.rfind('\nLinux version ')
        if pos >= 0:
            msgs = msgs[pos+1:]

        return msgs

    def diff_messages(self, message, logbuf):
        """
        compare log_buf and kernel messages in /var/log/messages and
        retrieve unrecorded message.

        message and logbuf are strings (lists of lines are accepted too).
        Returns the lines of the final block that `diff -u` reports as only
        in logbuf, without the leading '+'; each line ends with a newline.
        """

        (fd_mes, file_mes) = mymktemp('.diskdump')
        (fd_log, file_log) = mymktemp('.diskdump')

        write(fd_mes, ''.join(message))
        write(fd_log, ''.join(logbuf))

        close(fd_mes)
        close(fd_log)

        pipe = popen(self.CMD_DIFF % (file_mes, file_log))
        try:
            diff = pipe.readlines()
        finally:
            pipe.close()

        # Keep the temporary files around for inspection in debug mode.
        if not options.debug:
            remove(file_mes)
            remove(file_log)

        # diff prints "\ No newline at end of file" after a last line that
        # lacks '\n'; skip such markers so they do not hide the '+' block.
        end = len(diff)
        while end > 0 and diff[end - 1].startswith('\\'):
            end = end - 1

        # get the last '+' block from output of diff
        start = end
        while start > 0 and diff[start - 1].startswith('+'):
            start = start - 1

        # delete '+' from each messages.
        lines = [line[1:] for line in diff[start:end]]

        # Lines are written out as-is, so the last one needs its newline.
        if lines and not lines[-1].endswith('\n'):
            lines[-1] = lines[-1] + '\n'

        return lines

    def unrecorded_messages(self, logbuf):
        """
        compare log_buf and kernel messages in /var/log/messages (since the
        latest reboot) and retrieve unrecorded message
        """

        kmessages = self.kernelmessages_after_reboot()

        return self.diff_messages(kmessages, logbuf)

    def isnewer(self, date):
        """
        True if the messages file was modified more than 5 seconds after
        date (seconds since the epoch).
        """
        return getmtime(self.messagefile) - 5 > date

    def complement(self, logbuf, hostname, dry_run=True, death_time=None):
        """
        Write the messages found in logbuf but not in the messages file,
        framed by start/end marker lines, as "kernel:" syslog lines.

        logbuf: kernel log buffer text from the vmcore.
        hostname: host name put on each written line.
        dry_run: write to stdout instead of appending to the messages file.
        death_time: time (seconds since the epoch) used as the timestamp;
            defaults to the module-level `timeofdeath`.
        """
        unrecorded = self.unrecorded_messages(logbuf)

        # Nothing to salvage if diff found no lines, or only blank ones.
        if not [l for l in unrecorded if l.strip()]:
            verr('no unrecorded message exists')
            return

        if death_time is None:
            death_time = timeofdeath
        tod = strftime("%b %d %H:%M:%S", localtime(death_time))
        prefix = '%s %s kernel: ' % (tod, hostname)

        if dry_run:
            fd = stdout
        else:
            fd = file(self.messagefile, "a")

        try:
            fd.write(prefix + "--- salvaged messages from crash dump start\n")
            for l in unrecorded:
                fd.write(prefix + l)
            fd.write(prefix + "--- salvaged messages from crash dump end\n")
        finally:
            if fd is not stdout:
                fd.close()

# -----------------------------------------------------------------------
#
# Main
#

# parse options -f -v --dry-run -d
#
# `options` is a module-level global: verr() and derr() read
# options.verbose / options.debug, and Messages.diff_messages reads
# options.debug.  "--dry-run" is available as options.dry_run.

parser = OptionParser("diskdumpmsg [options] vmcore messagefile")
parser.add_option("-f", "--force", action="store_true", \
                  help="force to append salvaged messages")
parser.add_option("-v", "--verbose", action="store_true", \
                  help="print verbose messages")
parser.add_option("", "--dry-run", action="store_true", \
                  help="dry run")
parser.add_option("-d", "--debug", action="store_true", \
                  help="run with debug mode")
(options, args) = parser.parse_args()

if len(args) < 1:
    err('Usage: ' + cmdname + ' vmcore message')
    exit(1)

vmcorefile = args[0]

# The messages file defaults to /var/log/messages.
if len(args) >= 2:
    messagefile = args[1]
else:
    messagefile = "/var/log/messages"

# If syslogd or klogd is running, do nothing (unless --force), presumably
# so salvaged lines are not interleaved with a live syslogd's output.

if not options.force:
    if is_alive('/var/run/syslogd.pid') or is_alive('/var/run/klogd.pid'):
        verr('syslogd or klogd is running')
        exit(1)

# Pick ElfVmcore or CompressedVmcore from the file's magic and read the
# kernel log buffer out of the dump.
vmcore = Vmcore.generate(vmcorefile)
logbuf = vmcore.get_logbuf()

# Module-level global: Messages.complement reads `timeofdeath` for the
# timestamp of the salvaged lines.
timeofdeath = vmcore.timeofdeath()

messages = Messages(messagefile, force=options.force, dry_run=options.dry_run)

verr('last time of death:', timeofdeath)

# Skip when the messages file was modified after the crash (see
# Messages.isnewer); its contents are then presumably already current.
if messages.isnewer(timeofdeath) and not options.force:
    verr('vmcore is older than messages in /var/log/messages:', timeofdeath)
    exit(0)

messages.complement(logbuf, vmcore.get_hostname(), dry_run=options.dry_run)

# call destructor to remove temporary file.
# (Dropping the last reference lets CompressedVmcore.__del__ delete its
# uncompressed data file.)
vmcore = None
