#!/usr/bin/env python

"""
A simple utility that compares tool invocations and exit codes issued by
compiler drivers that support -### (e.g. gcc and clang).
"""

from itertools import zip_longest
import subprocess

def splitArgs(s):
    """Tokenize one line of driver ``-###`` output.

    Yields every double-quoted argument with its surrounding quotes
    kept and its backslash escape sequences left untouched.  Outside of
    quotes, whitespace is dropped and every other character is yielded
    as its own one-character token.

    A quoted argument that is never closed is discarded.
    """
    it = iter(s)
    current = ''
    inQuote = False
    for c in it:
        if c == '"':
            if inQuote:
                inQuote = False
                yield current + '"'
            else:
                inQuote = True
                current = '"'
        elif inQuote:
            current += c
            if c == '\\':
                # Copy the escaped character verbatim so that an escaped
                # quote does not terminate the argument.  The '' default
                # matters: a bare next() on an exhausted iterator would
                # raise StopIteration inside this generator, which Python
                # turns into RuntimeError (PEP 479).
                current += next(it, '')
        elif not c.isspace():
            yield c

def insertMinimumPadding(a, b, dist):
    """insertMinimumPadding(a, b, dist) -> (a', b')

    Return two lists of equal length, where Nones have been inserted
    into the shorter list to keep sum(map(dist, a', b')) low.  The
    Nones are placed greedily, one at a time: each goes to the position
    that gives the lowest total distance given the Nones already
    placed, with ties going to the earliest position.

    a and b must be lists.  Assumes dist(X, Y) -> int and non-negative;
    dist is also called with None as its first argument.
    """

    def cost(a, b):
        # Pad a with trailing Nones up to len(b) so every element of b
        # takes part in the sum.
        return sum(map(dist, a + [None] * (len(b) - len(a)), b))

    # Normalize so a is shortest.
    if len(b) < len(a):
        b, a = insertMinimumPadding(b, a, dist)
        return a, b

    # For each None we have to insert...
    for _ in range(len(b) - len(a)):
        # ...try every position we could insert it at and keep the
        # cheapest.  min() returns the first minimal candidate, so ties
        # resolve to the smallest insertion index.
        candidates = [a[:j] + [None] + a[j:] for j in range(len(a) + 1)]
        a = min(candidates, key=lambda padded: cost(padded, b))
    return a, b

class ZipperDiff(object):
    """Simple (slow) diff of two lists that only accommodates inserts."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def dist(self, a, b):
        """Return True when the two elements differ, False otherwise."""
        return a != b

    def getDiffs(self):
        """Yield the (a, b) element pairs that differ.

        The shorter list is first padded with Nones by
        insertMinimumPadding, so a None on one side of a yielded pair
        marks an element present only on the other side.
        """
        paddedA, paddedB = insertMinimumPadding(self.a, self.b, self.dist)
        for pair in zip(paddedA, paddedB):
            if self.dist(*pair):
                yield pair

class DriverZipperDiff(ZipperDiff):
    """ZipperDiff variant that treats any two temp-file arguments as equal."""

    def isTempFile(self, filename):
        """Return True if filename is a double-quoted path under /tmp/ or /var/.

        filename is expected to still carry its surrounding double
        quotes; anything not wrapped in quotes is not a temp file.
        """
        # startswith/endswith instead of indexing, so that an empty
        # string yields False rather than raising IndexError.
        if not (filename.startswith('"') and filename.endswith('"')):
            return False
        # Offset 1 skips the opening quote.
        return filename.startswith(('/tmp/', '/var/'), 1)

    def dist(self, a, b):
        """Return 0 when both elements are temp files, else the base distance.

        The truthiness checks keep the None padding elements (and empty
        strings) away from isTempFile.
        """
        if a and b and self.isTempFile(a) and self.isTempFile(b):
            return 0
        return super(DriverZipperDiff, self).dist(a, b)

class CompileInfo:
    """Result of one driver run, split into commands and diagnostics.

    Attributes:
      commands -- one argument list (from splitArgs) per stderr line
                  whose first non-blank character is a double quote
      stdout   -- the captured standard output, stored unchanged
      stderr   -- the remaining stderr lines, joined by newlines and
                  stripped of surrounding whitespace
      exitCode -- the exit status passed in as res
    """

    def __init__(self, out, err, res):
        # Driver banner lines are dropped so they are never compared.
        bannerPrefixes = ('Target: ',
                          'Configured with: ',
                          'Thread model: ',
                          'gcc version',
                          'clang version')

        # Standard out isn't used for much.
        self.stdout = out
        self.exitCode = res
        self.commands = []

        # FIXME: Compare error messages as well.
        remaining = []
        for line in err.split('\n'):
            if line == 'Using built-in specs.' or line.startswith(bannerPrefixes):
                continue
            if line.strip().startswith('"'):
                self.commands.append(list(splitArgs(line)))
            else:
                remaining.append(line)

        self.stderr = '\n'.join(remaining).strip()

def captureDriverInfo(cmd, args):
    """Run ``cmd -### args...`` and return the result as a CompileInfo.

    cmd is the driver executable and args is a list of further
    command-line arguments.  The driver's stdout, stderr and exit
    status are captured and handed to CompileInfo.
    """
    # universal_newlines=True makes the pipes return str instead of
    # bytes; CompileInfo and main() split the output on '\n', which
    # raises TypeError on bytes under Python 3.
    p = subprocess.Popen([cmd, '-###'] + args,
                         stdin=None,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         universal_newlines=True)
    # communicate() waits for the process to exit, so returncode is
    # already set when it returns.
    out, err = p.communicate()
    return CompileInfo(out, err, p.returncode)

def _printDiffs(diffs):
    """Print one report entry for each (aElt, bElt) pair in diffs.

    A None on either side means that side lacks the element.
    """
    for aElt, bElt in diffs:
        if aElt is None:
            print('A missing: %s' % bElt)
        elif bElt is None:
            print('B missing: %s' % aElt)
        else:
            print('mismatch: A: %s' % aElt)
            print('          B: %s' % bElt)


def main():
    """Run both drivers on the command-line arguments and report differences.

    The drivers are taken from the DRIVER_A and DRIVER_B environment
    variables, defaulting to 'gcc' and 'clang'.  Stdout, stderr, the
    tool invocations and the exit codes are compared in that order, and
    the process exits with status 1 if anything differs.
    """
    import os
    import sys

    args = sys.argv[1:]
    driverA = os.getenv('DRIVER_A') or 'gcc'
    driverB = os.getenv('DRIVER_B') or 'clang'

    infoA = captureDriverInfo(driverA, args)
    infoB = captureDriverInfo(driverB, args)

    differ = False

    # Compare stdout.
    if infoA.stdout != infoB.stdout:
        print('-- STDOUT DIFFERS -')
        print('A OUTPUT: ',infoA.stdout)
        print('B OUTPUT: ',infoB.stdout)
        print()

        diff = ZipperDiff(infoA.stdout.split('\n'),
                          infoB.stdout.split('\n'))
        _printDiffs(diff.getDiffs())

        differ = True

    # Compare stderr.
    if infoA.stderr != infoB.stderr:
        print('-- STDERR DIFFERS -')
        print('A STDERR: ',infoA.stderr)
        print('B STDERR: ',infoB.stderr)
        print()

        diff = ZipperDiff(infoA.stderr.split('\n'),
                          infoB.stderr.split('\n'))
        _printDiffs(diff.getDiffs())

        differ = True

    # Compare commands.
    for i,(a,b) in enumerate(zip_longest(infoA.commands, infoB.commands, fillvalue=None)):
        # A None here means one driver issued fewer commands.
        if a is None:
            print('A MISSING:',' '.join(b))
            differ = True
            continue
        elif b is None:
            print('B MISSING:',' '.join(a))
            differ = True
            continue

        # Materialize the diffs so the header is only printed when
        # there is at least one.
        diffs = list(DriverZipperDiff(a,b).getDiffs())
        if diffs:
            print('-- COMMAND %d DIFFERS -' % i)
            print('A COMMAND:',' '.join(a))
            print('B COMMAND:',' '.join(b))
            print()
            _printDiffs(diffs)
            differ = True

    # Compare result codes.
    if infoA.exitCode != infoB.exitCode:
        print('-- EXIT CODES DIFFER -')
        print('A: ',infoA.exitCode)
        print('B: ',infoB.exitCode)
        differ = True

    if differ:
        sys.exit(1)

# Run the comparison only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()
