Suggestions on how to speed up a distance calculation

久未见 submitted on 2019-11-30 09:53:14

The following Cython code (I realize the first line of __init__ differs from yours; I replaced it with placeholder code because I don't know your var function, and it doesn't matter anyway, since you stated that __call__ is the bottleneck):

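cdef class SquareErrorDistance:
    cdef double _norm  # typed C attribute, read directly from the object struct in __call__

    def __init__(self, dataSample):
        # Placeholder for var(dataSample): the original variance function is not shown.
        variance = round(sum(dataSample)/len(dataSample))
        if variance == 0:
            self._norm = 1.0
        else:
            self._norm = 1.0 / (2 * variance)

    def __call__(self, double u, double v): # u and v are floats, typed as C doubles
        return (u - v) ** 2 * self._norm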

Compiled via a simple setup.py (just the example from the Cython docs with the file name changed), it runs nearly 20 times faster than the equivalent pure Python in a simple, contrived timeit benchmark. Note that the only changes were cdef declarations for the _norm field and typed double parameters for __call__. I consider this pretty impressive.
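The "equivalent pure Python" in that comparison is presumably the same class without the cdef declarations. Below is a minimal sketch of such a baseline plus a timeit harness. It keeps the same placeholder in __init__. The sample data, the call arguments and the call count are arbitrary choices, and the helper name benchmark is made up. No timings are quoted, because they depend on the machine and the Python version.

```python
import timeit


class SquareErrorDistance:
    """Pure-Python equivalent of the Cython class: same code, no cdef declarations."""

    def __init__(self, dataSample):
        # Same placeholder as the Cython version: a rounded mean stands in for var().
        variance = round(sum(dataSample) / len(dataSample))
        if variance == 0:
            self._norm = 1.0
        else:
            self._norm = 1.0 / (2 * variance)

    def __call__(self, u, v):  # u and v are floats
        return (u - v) ** 2 * self._norm


def benchmark(number=1_000_000):
    """Return the seconds taken by `number` calls to a SquareErrorDistance instance."""
    dist = SquareErrorDistance([1.0, 2.0, 3.0])  # arbitrary sample data
    # Pass the instance via globals so the timed statement is just the call itself.
    return timeit.timeit("dist(1.5, 0.25)", globals={"dist": dist}, number=number)


if __name__ == "__main__":
    print(f"pure Python: {benchmark():.3f} s per 1,000,000 calls")
```

To compare the two versions, import the compiled extension under whatever module name the .pyx file was given. Then time it with the same statement.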

This probably won't help much, but you can rewrite it using nested functions:

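from statistics import pvariance as var  # assumption: var() is a population variance, like numpy.var


def SquareErrorDistance(dataSample):
    variance = var(list(dataSample))
    if variance == 0:
        # Zero variance: use a norm of 1, i.e. the plain squared error.
        def f(u, v):
            x = u - v
            return x * x
    else:
        # Compute the normalisation once; the closure captures it.
        norm = 1.0 / (2 * variance)

        def f(u, v):
            x = u - v
            return x * x * norm
    return f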
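There are a few reasons why the closure might help a little:

- The class-based __call__ looks up self._norm on every call and goes through the instance-call machinery.
- The closure reads norm from a captured cell variable and is invoked as a plain function.
- The closure uses x * x instead of ** 2, which is usually somewhat cheaper for floats in CPython.

The sketch below checks that the two forms agree and times them side by side. It assumes var is a population variance and substitutes statistics.pvariance for it. The class name SquareErrorDistanceClass and the sample data are made up for the comparison.

```python
import timeit
from statistics import pvariance as var  # assumption: var() is a population variance


class SquareErrorDistanceClass:
    """Class-based version: attribute lookup on self._norm at every call."""

    def __init__(self, dataSample):
        variance = var(list(dataSample))
        self._norm = 1.0 if variance == 0 else 1.0 / (2 * variance)

    def __call__(self, u, v):
        return (u - v) ** 2 * self._norm


def SquareErrorDistance(dataSample):
    """Closure-based version: norm is captured once, calls are plain function calls."""
    variance = var(list(dataSample))
    if variance == 0:
        def f(u, v):
            x = u - v
            return x * x
    else:
        norm = 1.0 / (2 * variance)

        def f(u, v):
            x = u - v
            return x * x * norm
    return f


if __name__ == "__main__":
    sample = [1.0, 2.0, 3.0, 4.0]  # hypothetical data
    as_class = SquareErrorDistanceClass(sample)
    as_closure = SquareErrorDistance(sample)

    # Both forms should give the same distance.
    assert abs(as_class(1.5, 0.25) - as_closure(1.5, 0.25)) < 1e-12

    for name, fn in [("class", as_class), ("closure", as_closure)]:
        t = timeit.timeit("fn(1.5, 0.25)", globals={"fn": fn}, number=1_000_000)
        print(f"{name:8s}: {t:.3f} s")
```

In line with "probably won't help much", expect only a modest gap here. The Cython route from the first answer is likely to matter more.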