Developing a heuristic to test simple anonymous Python functions for equivalency

☆樱花仙子☆ submitted on 2019-12-02 20:51:39

Edit: this now also checks whether external state can affect the sorting function, in addition to whether the two functions are equivalent.


I hacked up dis.dis and friends to output to a global file-like object. I then stripped out line numbers and normalized variable names (without touching constants) and compared the result.

You could clean this up so that dis.dis and friends yield lines instead, so you wouldn't have to trap their output. But this is a working proof of concept for comparing functions with dis.dis, with minimal changes.
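In Python 3.4 and later, the standard library already provides that generator-style interface: `dis.get_instructions()` yields `Instruction` objects instead of printing text. The sketch below is a Python 3 adaptation, not the answer's code. `fingerprint` and `bytecode_equivalent` are hypothetical helper names. It applies the same normalization: line numbers are dropped, local and closure variable names are replaced by their slot indices, and constants and global names are kept.

```python
import dis


def fingerprint(func):
    """Return a tuple summarizing func's bytecode with local names normalized.

    Hypothetical helper: only (opname, argument) pairs are kept, so line
    numbers disappear; local/closure variables are reduced to their slot
    index; every other argument (constants, global names, jump targets)
    is kept by value.
    """
    result = []
    for ins in dis.get_instructions(func):
        if ins.opcode in dis.haslocal or ins.opcode in dis.hasfree:
            # Renaming x -> y leaves the slot index unchanged.
            arg = ins.arg
        else:
            # Pair the value with its type so that 1 and 1.0 stay distinct.
            arg = (type(ins.argval), ins.argval)
        result.append((ins.opname, arg))
    return tuple(result)


def bytecode_equivalent(f, g):
    """True if f and g compile to the same normalized instruction sequence."""
    return fingerprint(f) == fingerprint(g)


print(bytecode_equivalent(lambda x: 2 * x, lambda y: 2 * y))  # True
print(bytecode_equivalent(lambda x: 2 * x, lambda x: x * 2))  # False
```

Bytecode changes between Python versions, so fingerprints are only meaningful when compared within the same interpreter.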

import types
from opcode import *
_have_code = (types.MethodType, types.FunctionType, types.CodeType,
              types.ClassType, type)

def dis(x):
    """Disassemble classes, methods, functions, or code.

    Output is written to the global file-like object `out`.

    """
    if isinstance(x, types.InstanceType):
        x = x.__class__
    if hasattr(x, 'im_func'):
        x = x.im_func
    if hasattr(x, 'func_code'):
        x = x.func_code
    if hasattr(x, '__dict__'):
        items = x.__dict__.items()
        items.sort()
        for name, x1 in items:
            if isinstance(x1, _have_code):
                print >> out,  "Disassembly of %s:" % name
                try:
                    dis(x1)
                except TypeError, msg:
                    print >> out,  "Sorry:", msg
                print >> out
    elif hasattr(x, 'co_code'):
        disassemble(x)
    elif isinstance(x, str):
        disassemble_string(x)
    else:
        raise TypeError, \
              "don't know how to disassemble %s objects" % \
              type(x).__name__

def disassemble(co, lasti=-1):
    """Disassemble a code object."""
    code = co.co_code
    labels = findlabels(code)
    linestarts = dict(findlinestarts(co))
    n = len(code)
    i = 0
    extended_arg = 0
    free = None
    while i < n:
        c = code[i]
        op = ord(c)
        if i in linestarts:
            if i > 0:
                print >> out
            print >> out,  "%3d" % linestarts[i],
        else:
            print >> out,  '   ',

        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(20),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == EXTENDED_ARG:
                extended_arg = oparg*65536L
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                print >> out,  '(' + repr(co.co_consts[oparg]) + ')',
            elif op in hasname:
                print >> out,  '(' + co.co_names[oparg] + ')',
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                print >> out,  '(' + co.co_varnames[oparg] + ')',
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
            elif op in hasfree:
                if free is None:
                    free = co.co_cellvars + co.co_freevars
                print >> out,  '(' + free[oparg] + ')',
        print >> out

def disassemble_string(code, lasti=-1, varnames=None, names=None,
                       constants=None):
    labels = findlabels(code)
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        if i == lasti: print >> out,  '-->',
        else: print >> out,  '   ',
        if i in labels: print >> out,  '>>',
        else: print >> out,  '  ',
        print >> out,  repr(i).rjust(4),
        print >> out,  opname[op].ljust(15),
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            print >> out,  repr(oparg).rjust(5),
            if op in hasconst:
                if constants:
                    print >> out,  '(' + repr(constants[oparg]) + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasname:
                if names is not None:
                    print >> out,  '(' + names[oparg] + ')',
                else:
                    print >> out,  '(%d)'%oparg,
            elif op in hasjrel:
                print >> out,  '(to ' + repr(i + oparg) + ')',
            elif op in haslocal:
                if varnames:
                    print >> out,  '(' + varnames[oparg] + ')',
                else:
                    print >> out,  '(%d)' % oparg,
            elif op in hascompare:
                print >> out,  '(' + cmp_op[oparg] + ')',
        print >> out

def findlabels(code):
    """Detect all offsets in a byte code which are jump targets.

    Return the list of offsets.

    """
    labels = []
    n = len(code)
    i = 0
    while i < n:
        c = code[i]
        op = ord(c)
        i = i+1
        if op >= HAVE_ARGUMENT:
            oparg = ord(code[i]) + ord(code[i+1])*256
            i = i+2
            label = -1
            if op in hasjrel:
                label = i+oparg
            elif op in hasjabs:
                label = oparg
            if label >= 0:
                if label not in labels:
                    labels.append(label)
    return labels

def findlinestarts(code):
    """Find the offsets in a byte code which are start of lines in the source.

    Generate pairs (offset, lineno) as described in Python/compile.c.

    """
    byte_increments = [ord(c) for c in code.co_lnotab[0::2]]
    line_increments = [ord(c) for c in code.co_lnotab[1::2]]

    lastlineno = None
    lineno = code.co_firstlineno
    addr = 0
    for byte_incr, line_incr in zip(byte_increments, line_increments):
        if byte_incr:
            if lineno != lastlineno:
                yield (addr, lineno)
                lastlineno = lineno
            addr += byte_incr
        lineno += line_incr
    if lineno != lastlineno:
        yield (addr, lineno)

class FakeFile(object):
    def __init__(self):
        self.store = []
    def write(self, data):
        self.store.append(data)

a = lambda x : x
b = lambda x : x # True
c = lambda x : 2 * x
d = lambda y : 2 * y # True
e = lambda x : 2 * x
f = lambda x : x * 2 # True or False is fine, but must be stable
g = lambda x : 2 * x
h = lambda x : x + x # True or False is fine, but must be stable

funcs = a, b, c, d, e, f, g, h

outs = []
for func in funcs:
    out = FakeFile()
    dis(func)
    outs.append(out.store)

import ast

def outfilter(out):
    for i in out:
        if i.strip().isdigit():
            continue
        if '(' in i:
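            try:
                ast.literal_eval(i)  # constants such as "(2)" are kept as-is
            except ValueError:
                i = "(x)"  # a variable name such as "(y)": normalize it
            except SyntaxError:
                pass  # jump targets "(to 12)" and compare ops "(<)": keep as-is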
        yield i

# "Polluted": the bytecode reads a global (LOAD_GLOBAL) or a closure variable (LOAD_DEREF)
processed_outs = [(out, 'LOAD_GLOBAL' in out or 'LOAD_DEREF' in out)
                  for out in (''.join(outfilter(out)) for out in outs)]

for (out1, polluted1), (out2, polluted2) in zip(processed_outs[::2], processed_outs[1::2]):
    print 'Bytecode Equivalent:', out1 == out2, '\nPolluted by state:', polluted1 or polluted2

The "Bytecode Equivalent" results are True, True, False, and False, and they are stable. The "Polluted" bool is True if the function's output depends on external state: either global state or a closure.
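A Python 3 sketch of the same "polluted" check follows. It assumes that reading a global or builtin name (`LOAD_GLOBAL`, `LOAD_NAME`) or having free variables (`co_freevars`, i.e. being a closure) is what counts as external state. `depends_on_external_state` is a hypothetical name. Scanning instructions rather than `co_names` avoids flagging attribute accesses such as `x.real`.

```python
import dis

# Opcodes that read module-level or builtin names (an assumption about what
# counts as "external state"; LOAD_NAME normally appears only outside functions).
_GLOBAL_LOADS = {"LOAD_GLOBAL", "LOAD_NAME"}


def depends_on_external_state(func):
    """True if func reads a global/builtin name or closes over outer variables.

    Hypothetical helper. Only func's own code object is inspected; functions
    or lambdas nested inside it are not.
    """
    code = func.__code__
    if code.co_freevars:  # non-empty when func is a closure
        return True
    return any(ins.opname in _GLOBAL_LOADS for ins in dis.get_instructions(code))


scale = 3
print(depends_on_external_state(lambda x: 2 * x))      # False
print(depends_on_external_state(lambda x: scale * x))  # True
```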

So, let's address some technical issues first.

1) Bytecode: this is probably not a problem because, instead of inspecting the .pyc (compiled binary) files, you can use the dis module to get the bytecode, e.g.:

>>> f = lambda x, y : x+y
>>> dis.dis(f)
  1           0 LOAD_FAST                0 (x)
              3 LOAD_FAST                1 (y)
              6 BINARY_ADD          
              7 RETURN_VALUE 

There is no need to worry about the platform (the bytecode does depend on the Python version, though).
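The bytecode is also available directly on the function object through `__code__`, with no .pyc file involved. The minimal Python 3 sketch below compares the raw opcodes in `co_code` together with the constants and names, which the code object stores in separate tuples. `code_signature` and `same_code` are hypothetical helpers, and the set of fields compared is an illustrative choice, not an exhaustive one (defaults, flags and exception tables are ignored). Because `co_code` stores argument indices rather than variable names, renaming a parameter does not change it.

```python
def code_signature(func):
    """Hypothetical helper: a few code-object fields used for comparison."""
    code = func.__code__
    return (
        code.co_argcount,
        code.co_code,  # raw opcodes + argument indices; no names, no line numbers
        tuple((type(c), c) for c in code.co_consts),  # typed, so 1 != 1.0
        code.co_names,  # global and attribute names are compared by value
    )


def same_code(f, g):
    return code_signature(f) == code_signature(g)


print(same_code(lambda x: x + 1, lambda y: y + 1))  # True
print(same_code(lambda x: x + 1, lambda x: x + 2))  # False
```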

2) Tokenized source code: again, Python has everything you need. You can use the ast module to parse the code and obtain its AST.

>>> a = ast.parse("f = lambda x, y : x+y")
>>> ast.dump(a)
"Module(body=[Assign(targets=[Name(id='f', ctx=Store())], value=Lambda(args=arguments(args=[Name(id='x', ctx=Param()), Name(id='y', ctx=Param())], vararg=None, kwarg=None, defaults=[]), body=BinOp(left=Name(id='x', ctx=Load()), op=Add(), right=Name(id='y', ctx=Load()))))])"

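To compare two lambdas at the AST level, parameter names have to be normalized first, just as with the bytecode. A minimal Python 3 sketch, assuming the lambdas are available as source strings (`normalized_lambda_dump` and `_RenameParams` are hypothetical names), renames positional parameters to `_0`, `_1`, ... with `ast.NodeTransformer` and compares the `ast.dump()` output. `ast.dump` omits line and column attributes by default.

```python
import ast


class _RenameParams(ast.NodeTransformer):
    """Rename the given parameter names (hypothetical helper)."""

    def __init__(self, mapping):
        self.mapping = mapping

    def visit_arg(self, node):
        node.arg = self.mapping.get(node.arg, node.arg)
        return node

    def visit_Name(self, node):
        node.id = self.mapping.get(node.id, node.id)
        return node


def normalized_lambda_dump(source):
    """Parse a lambda expression and dump its AST with parameters renamed."""
    tree = ast.parse(source, mode="eval")
    lam = tree.body
    if not isinstance(lam, ast.Lambda):
        raise ValueError("expected a lambda expression")
    # Only plain positional parameters are handled; nested lambdas that reuse
    # the same names are not treated specially.
    mapping = {a.arg: "_%d" % i for i, a in enumerate(lam.args.args)}
    return ast.dump(_RenameParams(mapping).visit(tree))


print(normalized_lambda_dump("lambda x: 2 * x") == normalized_lambda_dump("lambda y: 2 * y"))  # True
print(normalized_lambda_dump("lambda x: 2 * x") == normalized_lambda_dump("lambda x: x + x"))  # False
```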
So, the question we should really address is: is it feasible to determine that two functions are equivalent analytically?

It is easy for a human to see that 2*x equals x+x, but how can we create an algorithm to prove it?
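For a deliberately restricted language, this can be decided mechanically. Suppose a lambda body uses only integer constants, its parameters, unary minus, `+`, `-` and `*`. Then it can be expanded into a canonical polynomial, and two such lambdas agree on all integer inputs exactly when their polynomials are equal. The Python 3.8+ sketch below does this (`_to_poly` and `polynomially_equal` are hypothetical names); anything outside that subset raises `ValueError`. The result does not carry over to floats, where for example `(x + 1) * (x - 1)` and `x * x - 1` can round differently.

```python
import ast
from collections import defaultdict


def _to_poly(node, params):
    """Expand an expression into {monomial: coefficient}.

    A monomial is a tuple of exponents, one per lambda parameter. Only integer
    constants, the parameters, unary minus, +, - and * are supported.
    """
    n = len(params)
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return {(0,) * n: node.value} if node.value else {}
    if isinstance(node, ast.Name) and node.id in params:
        exps = [0] * n
        exps[params.index(node.id)] = 1
        return {tuple(exps): 1}
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return {m: -c for m, c in _to_poly(node.operand, params).items()}
    if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub, ast.Mult)):
        left = _to_poly(node.left, params)
        right = _to_poly(node.right, params)
        result = defaultdict(int)
        if isinstance(node.op, ast.Mult):
            # Multiply every pair of terms: exponents add, coefficients multiply.
            for m1, c1 in left.items():
                for m2, c2 in right.items():
                    result[tuple(a + b for a, b in zip(m1, m2))] += c1 * c2
        else:
            sign = 1 if isinstance(node.op, ast.Add) else -1
            for m, c in left.items():
                result[m] += c
            for m, c in right.items():
                result[m] += sign * c
        # Drop cancelled terms so equal polynomials give identical dicts.
        return {m: c for m, c in result.items() if c != 0}
    raise ValueError("unsupported expression: " + ast.dump(node))


def polynomially_equal(src_a, src_b):
    """True if two lambda sources denote the same integer polynomial."""
    lam_a = ast.parse(src_a, mode="eval").body
    lam_b = ast.parse(src_b, mode="eval").body
    params_a = [a.arg for a in lam_a.args.args]
    params_b = [a.arg for a in lam_b.args.args]
    if len(params_a) != len(params_b):
        return False
    # Parameters are matched by position, so renaming them does not matter.
    return _to_poly(lam_a.body, params_a) == _to_poly(lam_b.body, params_b)


print(polynomially_equal("lambda x: 2 * x", "lambda y: y + y"))  # True
print(polynomially_equal("lambda x: 2 * x", "lambda x: 3 * x"))  # False
```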

If that is what you want to achieve, you may want to look into computer-assisted proofs: http://en.wikipedia.org/wiki/Computer-assisted_proof

However, if ultimately you simply want to check that two datasets are sorted in the same order, you just need to run sort function A on dataset B and vice versa, and then compare the outcomes. If they are identical, the functions are probably functionally identical. Of course, the check is only valid for those datasets.
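One way to run that kind of empirical check in Python 3 is to sort the same inputs with both key functions and compare the results. The datasets here are an assumption: seeded random integers plus one hand-picked case. `same_sort_order` is a hypothetical helper name. As noted above, agreement only holds for the data that was tried.

```python
import random


def same_sort_order(key_a, key_b, datasets):
    """True if key_a and key_b sort every dataset into the same order.

    This only shows agreement on the data that was tried, not equivalence.
    """
    return all(sorted(data, key=key_a) == sorted(data, key=key_b) for data in datasets)


rng = random.Random(0)  # fixed seed so the check is reproducible
datasets = [[rng.randint(-100, 100) for _ in range(50)] for _ in range(20)]
datasets.append([3, -1, 2])  # a small hand-picked case

print(same_sort_order(lambda x: 2 * x, lambda x: x + x, datasets))  # True
print(same_sort_order(lambda x: x, lambda x: -x, datasets))         # False
```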
