How can I subsample an array according to its density? (Remove frequent values, keep rare ones)

Submitted by 主宰稳场 on 2019-12-28 16:24:15


I want to plot a data distribution where some values occur frequently while others are quite rare. There are around 30,000 points in total. Rendering such a plot as PNG or (god forbid) PDF takes forever, and the PDF is much too large to display.

So I want to subsample the data just for the plots. What I would like to achieve is to remove many of the points where they overlap (where the density is high), but keep the ones where the density is low with a probability close to 1.

Now, numpy.random.choice allows one to specify a vector of probabilities, which I've computed from the data histogram with a few tweaks. But I can't seem to tune the choice so that the rare points are actually kept.

I've attached an image of the data; the right tail of the distribution has orders of magnitude fewer points, so I'd like to keep those. The data is 3D, but the density comes from only one dimension, so I can use that as a measure of how many points are in a given location.
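A minimal sketch of the histogram idea, assuming the density dimension can be passed in as a 1-D array: instead of drawing a fixed-size sample with numpy.random.choice, cap the number of points per histogram bin. Bins holding at most `cap` points are kept in full, so the rare tail always survives, and only crowded bins are randomly thinned. The function name `cap_per_bin` and its parameters (`n_bins`, `cap`, `seed`) are hypothetical, chosen for illustration.

```python
import numpy as np

def cap_per_bin(values, n_bins=100, cap=20, seed=None):
    """Return sorted indices of a subsample with at most `cap` points per bin.

    Hypothetical helper: bins with <= cap points are kept entirely, so rare
    values always survive; dense bins are randomly thinned down to `cap` points.
    """
    rng = np.random.default_rng(seed)
    values = np.asarray(values)
    edges = np.histogram_bin_edges(values, bins=n_bins)
    # Bin index of each point; clip so the maximum value lands in the last bin
    b = np.clip(np.digitize(values, edges) - 1, 0, n_bins - 1)
    keep = []
    for i in np.unique(b):
        members = np.flatnonzero(b == i)      # indices of the points in bin i
        if len(members) > cap:
            # crowded bin: keep a random subset of `cap` points
            members = rng.choice(members, size=cap, replace=False)
        keep.append(members)
    return np.sort(np.concatenate(keep))
```

For 3D data whose density is measured along, say, the third column, `sample = data[cap_per_bin(data[:, 2], n_bins=200, cap=50)]` would keep every point that lies in a bin with at most 50 points. The column index and the parameter values are placeholders to tune.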


One possible approach is to use kernel density estimation (KDE) to build an estimated probability distribution of the data, then sample according to the inverse of the estimated probability density of each point (or some other function that decreases as the estimated density increases). There are a few tools to compute a KDE in Python; a simple one is scipy.stats.gaussian_kde. Here is an example of the idea:

import numpy as np
import scipy.stats
import matplotlib.pyplot as plt

# Make some random Gaussian data
data = np.random.multivariate_normal([1, 1], [[1, 0], [0, 1]], size=1000)
# Compute KDE
kde = scipy.stats.gaussian_kde(data.T)
# Choice probabilities are computed from inverse probability density in KDE
p = 1 / kde.pdf(data.T)
# Normalize choice probabilities
p /= np.sum(p)
# Make sample using choice probabilities
idx = np.random.choice(np.arange(len(data)), size=100, replace=False, p=p)
sample = data[idx]
# Plot
plt.scatter(data[:, 0], data[:, 1], label='Data', s=10)
plt.scatter(sample[:, 0], sample[:, 1], label='Sample', s=7)
plt.legend()
plt.show()
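With numpy.random.choice, `replace=False` and a fixed `size` force exactly that many draws, so a rare point is kept with high probability but not with certainty. A hedged variation on the KDE idea keeps each point independently with probability min(1, c / density), where c is a low quantile of the estimated densities: every point at or below that quantile is then kept for sure, and the sample size becomes random instead of fixed. This is a sketch, not part of the answer above; `thin_by_density` and `keep_quantile` are hypothetical names, and the density is estimated on a single 1-D array because the question says one dimension carries the density.

```python
import numpy as np
import scipy.stats

def thin_by_density(values, keep_quantile=0.1, seed=None):
    """Hypothetical helper: boolean mask keeping each point with
    probability min(1, c / density).

    c is the `keep_quantile` quantile of the estimated densities, so every
    point whose density is at or below that quantile is always kept.
    """
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    kde = scipy.stats.gaussian_kde(values)   # 1-D KDE of the density dimension
    dens = kde(values)                       # estimated density at each point
    c = np.quantile(dens, keep_quantile)
    p_keep = np.minimum(1.0, c / dens)       # low density -> probability 1
    # rng.random is in [0, 1), so points with p_keep == 1 always pass
    return rng.random(len(values)) < p_keep
```

Evaluating `gaussian_kde` at all n data points costs on the order of n² kernel evaluations, so for ~30,000 points a histogram-based density estimate may be the cheaper choice.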



Consider the following function. It bins the data into equal-width bins along the x axis and

  • if there are one or two points in a bin, keeps those points,
  • if there are more points in a bin, keeps only the points with the minimum and maximum y value,
  • finally adds the first and last point to make sure the same data range is used.

This keeps the original data in regions of low density but significantly reduces the amount of data to plot in regions of high density. At the same time, all features are preserved as long as the binning is sufficiently dense.

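import numpy as np
np.random.seed(42)

def filt(x, y, bins):
    d = np.digitize(x, bins)          # bin index of every x value
    xfilt = []
    yfilt = []
    for i in np.unique(d):
        xi = x[d == i]
        yi = y[d == i]
        if len(xi) <= 2:
            # sparse bin: keep the original points
            xfilt.extend(xi)
            yfilt.extend(yi)
        else:
            # dense bin: keep only the points with the largest and smallest y
            xfilt.extend([xi[np.argmax(yi)], xi[np.argmin(yi)]])
            yfilt.extend([yi.max(), yi.min()])
    # prepend/append first/last point if necessary
    if x[0] != xfilt[0]:
        xfilt = [x[0]] + xfilt
        yfilt = [y[0]] + yfilt
    if x[-1] != xfilt[-1]:
        xfilt.append(x[-1])
        yfilt.append(y[-1])
    # return the selected points ordered by x
    sort = np.argsort(xfilt)
    return np.array(xfilt)[sort], np.array(yfilt)[sort]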

To illustrate the concept, let's use some toy data:

x = np.array([1,2,3,4, 6,7,8,9, 11,14, 17, 26,28,29])
y = np.array([4,2,5,3, 7,3,5,5, 2, 4,  5,  2,5,3])
bins = np.linspace(0,30,7)

Then call xf, yf = filt(x,y,bins) and plot both the original and the filtered data to compare them.
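The Python loop over bins can also be avoided. The sketch below is a hypothetical `filt_vectorized` (not part of the original answer) that selects the same points as `filt`: it sorts the indices by bin and then by y, so each bin's minimum and maximum sit at the two ends of its run. It assumes `bins` is the same edge array passed to `np.digitize`. When several points in a bin share the extreme y value, it may pick a different one of them than `np.argmax`/`np.argmin` would.

```python
import numpy as np

def filt_vectorized(x, y, bins):
    """Hypothetical loop-free variant of filt: per bin keep the min-y and
    max-y point (all points if the bin holds one or two), plus the first
    and last point."""
    x = np.asarray(x)
    y = np.asarray(y)
    d = np.digitize(x, bins)
    # Sort point indices by bin number, then by y inside each bin
    order = np.lexsort((y, d))
    # Start of each bin's run in the sorted order
    _, first = np.unique(d[order], return_index=True)
    # End of each bin's run: one before the next run starts
    last = np.append(first[1:], len(order)) - 1
    # min-y and max-y point of every bin, plus the first and last input point
    keep = np.unique(np.concatenate([order[first], order[last], [0, len(x) - 1]]))
    # Return in ascending x order, as filt does
    keep = keep[np.argsort(x[keep], kind="stable")]
    return x[keep], y[keep]

x = np.array([1,2,3,4, 6,7,8,9, 11,14, 17, 26,28,29])
y = np.array([4,2,5,3, 7,3,5,5, 2, 4,  5,  2,5,3])
bins = np.linspace(0,30,7)
xf, yf = filt_vectorized(x, y, bins)
# xf -> [1, 2, 3, 6, 7, 11, 14, 17, 26, 28, 29]
```

On this toy data it returns the same 11 points as `filt`.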

The following shows the use case from the question, with some 30,000 data points. The presented technique reduces the number of plotted points from 30,000 to about 500. This number of course depends on the binning in use, here 300 bins. In this case the function takes ~10 ms to compute. This is not super-fast, but still a large improvement compared to plotting all the points.

import matplotlib.pyplot as plt

# Generate some data
x = np.sort(np.random.rayleigh(3, size=30000))
y = np.cumsum(np.random.randn(len(x)))+250
# Decide for a number of bins
bins = np.linspace(x.min(),x.max(),301)
# Filter data
xf, yf = filt(x,y,bins) 

# Plot results
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, figsize=(7,8))

ax1.hist(x, bins=bins)

ax2.plot(x,y, linewidth=1, label="original data, {} points".format(len(x)))

ax3.plot(xf, yf, linewidth=1, label="binned min/max, {} points".format(len(xf)))

for ax in [ax2, ax3]:
    ax.legend()
plt.show()