1D Wasserstein distance in Python

Submitted by 十年热恋 on 2020-12-13 05:23:13

Question


The formula below is a special case of the Wasserstein distance (optimal transport) for the case where the source and target distributions, x and y (also called marginal distributions), are 1D, that is, vectors:

$$W_p(u, v) = \left( \int_0^1 \left| F_u^{-1}(z) - F_v^{-1}(z) \right|^p \, dz \right)^{1/p}$$

where F_u^{-1} and F_v^{-1} are the inverse cumulative distribution functions (quantile functions) of the marginals u and v, estimated from real data called x and y, both generated here from the normal distribution:

import numpy as np
from numpy.random import randn
import scipy.stats as ss

n = 100
x = randn(n)
y = randn(n)

How can the integral in the formula be coded in Python with SciPy? I'm guessing that x and y have to be converted to ranked marginals, which are non-negative and sum to 1, and that SciPy's ppf could then be used to calculate the inverses F^{-1}?


Answer 1:


Note that when n gets large we have that a sorted set of n samples approaches the inverse CDF sampled at 1/n, 2/n, ..., n/n. E.g.:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

n = 1000
# Evaluate the inverse CDF at cell midpoints; ppf(0) and ppf(1) are -inf and +inf.
z = (np.arange(n) + 0.5) / n
plt.plot(norm.ppf(z), label="invcdf")
plt.plot(np.sort(np.random.normal(size=n)), label="sortsample")
plt.legend()
plt.show()

Also note that your integral from 0 to 1 can be approximated as a sum over 1/n, 2/n, ..., n/n.
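
Written out, the two observations combine as follows. This is only a sketch of the approximation; the notation u_(i), v_(i) for the i-th smallest sample of each data set is an assumption introduced here, not something from the question:

```latex
W_p(u, v)
  = \left( \int_0^1 \left| F_u^{-1}(z) - F_v^{-1}(z) \right|^p \, dz \right)^{1/p}
  \approx \left( \frac{1}{n} \sum_{i=1}^{n} \left| u_{(i)} - v_{(i)} \right|^p \right)^{1/p}
```

The sum is just the mean of the p-th powers of the differences between the two sorted samples.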

Thus we can simply answer your question:

def W(p, u, v):
    assert len(u) == len(v)
    return np.mean(np.abs(np.sort(u) - np.sort(v))**p)**(1/p)
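
As a sanity check, for p = 1 and equal sample sizes the sorted-sample formula should agree with scipy.stats.wasserstein_distance, which computes the first Wasserstein distance between two empirical distributions. A minimal sketch; the name w_sorted, the seed, and the sample parameters are arbitrary choices made for illustration:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def w_sorted(p, u, v):
    # Same formula as W above, repeated so this sketch runs on its own.
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    assert len(u) == len(v)
    return np.mean(np.abs(np.sort(u) - np.sort(v)) ** p) ** (1 / p)

rng = np.random.default_rng(0)      # arbitrary seed
x = rng.normal(size=500)
y = rng.normal(loc=1.0, size=500)   # shifted sample, illustrative only

ours = w_sorted(1, x, y)
ref = wasserstein_distance(x, y)    # SciPy's W_1 for empirical samples
print(ours, ref)                    # the two values should match closely
```

For other values of p there is no such direct SciPy counterpart used here, so the check only covers the p = 1 case.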

Note that if len(u) != len(v) you can still apply the method with linear interpolation:

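def W(p, u, v):
    u = np.sort(u)
    v = np.sort(v)
    if len(u) != len(v):
        # Make u the shorter sample, then resample it onto v's quantile grid.
        if len(u) > len(v): u, v = v, u
        us = np.linspace(0, 1, len(u))
        vs = np.linspace(0, 1, len(v))
        # np.interp(x, xp, fp): evaluate sorted u (values at positions us) at positions vs.
        u = np.interp(vs, us, u)
    return np.mean(np.abs(u - v)**p)**(1/p)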
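
The unequal-length case can be tried out along the following lines. This is a sketch under stated assumptions: w_interp is a hypothetical stand-alone helper that uses np.interp to stretch the shorter sorted sample onto the longer one's grid, and the sample sizes and seed are arbitrary. Since interpolation is only an approximation, the p = 1 result is compared loosely against scipy.stats.wasserstein_distance, which accepts samples of different sizes:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def w_interp(p, u, v):
    # Hypothetical helper: sort both samples, then linearly interpolate
    # the shorter one onto the longer one's grid of quantile positions.
    u, v = np.sort(u), np.sort(v)
    if len(u) > len(v):
        u, v = v, u                      # ensure u is the shorter sample
    if len(u) != len(v):
        us = np.linspace(0, 1, len(u))   # positions of the sorted u values
        vs = np.linspace(0, 1, len(v))   # positions to evaluate at
        u = np.interp(vs, us, u)
    return np.mean(np.abs(u - v) ** p) ** (1 / p)

rng = np.random.default_rng(1)           # arbitrary seed
a = rng.normal(size=2000)
b = rng.normal(loc=0.5, size=3000)

approx = w_interp(1, a, b)
exact = wasserstein_distance(a, b)       # W_1 between the two empirical samples
print(approx, exact)                     # expected to be close, not identical
```

Swapping so that the shorter sample is always the one being interpolated keeps the function symmetric in its two data arguments.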

An alternative, if you have prior information about the family of distribution your data comes from but not its parameters, is to fit that family to both u and v (e.g. with scipy.stats.norm.fit) and then evaluate the integral numerically to the desired precision. E.g.:

from scipy.stats import norm as gauss
def W_gauss(p, u, v, num_steps):
    ud = gauss(*gauss.fit(u))
    vd = gauss(*gauss.fit(v))
    z = np.linspace(0, 1, num_steps, endpoint=False) + 1/(2*num_steps)
    return np.mean(np.abs(ud.ppf(z) - vd.ppf(z))**p)**(1/p)
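
One way to gauge how many steps the midpoint rule needs is the case p = 2 with two normal distributions, where the distance has the closed form sqrt((mu1 - mu2)^2 + (sigma1 - sigma2)^2). The sketch below compares the two; w_quantile is a hypothetical helper that takes already-frozen distributions instead of fitting them, and the parameter values are made up for illustration:

```python
import numpy as np
from scipy.stats import norm

def w_quantile(p, ud, vd, num_steps):
    # Midpoint rule on (0, 1): z never hits 0 or 1, where ppf is infinite.
    z = (np.arange(num_steps) + 0.5) / num_steps
    return np.mean(np.abs(ud.ppf(z) - vd.ppf(z)) ** p) ** (1 / p)

# Illustrative parameters, not taken from the question.
mu1, s1 = 0.0, 1.0
mu2, s2 = 1.5, 2.0

numeric = w_quantile(2, norm(mu1, s1), norm(mu2, s2), 100_000)
closed = np.hypot(mu1 - mu2, s1 - s2)   # closed-form W_2 for two 1D normals
print(numeric, closed)                  # should agree to several digits
```

Most of the remaining discrepancy comes from the tails, where the quantile function grows quickly, so increasing num_steps tightens the agreement.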


Source: https://stackoverflow.com/questions/65175268/1d-wasserstein-distance-in-python
