Generate a B-Spline basis in SciPy, like bs() in R

廉价感情. Submitted on 2021-02-08 07:38:42

Question


Given N one-dimensional data points X, I would like to evaluate K cubic B-spline basis functions at each point. In R, there is a simple function with an intuitive API for this, called bs. There is a Python package, patsy, which replicates it, but I can't use that package; I am limited to scipy and similar libraries.

Having looked through the scipy.interpolate documentation on spline-related functions, the closest I can find is BSpline, or BSpline.basis_element, but how to get just the K basis functions is totally mysterious to me. I tried the following:

import numpy as np
import scipy.interpolate as intrp
import matplotlib.pyplot as plt
import patsy # for comparison

# in Patsy/R: nice and sensible
x = np.linspace(0., 1., 100)
y = patsy.bs(x, knots=np.linspace(0,1,4), degree=3)
plt.subplot(1,2,1)
plt.plot(x,y)
plt.title('B-spline basis')

# in scipy: ?????
y_py = np.zeros((x.shape[0], 6))
for i in range(6):
    y_py[:,i] = intrp.BSpline(np.linspace(0,1,10),(np.arange(6)==i).astype(float), 3, extrapolate=False)(x)


plt.subplot(1,2,2)
plt.plot(x,y_py)
plt.title('Something else')


It doesn't work, and it makes me realise I don't actually know what this function is doing. First of all, it will not accept fewer than 8 interior knots, and I don't understand why. Secondly, it seems to think that the splines are only defined within roughly the (1/3, 2/3) range, which may mean that it is ignoring the first 3 and last 3 knot values for some reason. Do I need to pad the knots?

Any help would be appreciated!

EDIT: I have solved this discrepancy. It does indeed seem that BSpline ignores the first 3 and last 3 values of the knots. I'm still interested in knowing why there is this discrepancy, so that I feel less bad about the odd hour spent debugging a strange interface.

For posterity, here is the code that does produce the basis functions:

import numpy as np
import scipy.interpolate as intrp
import matplotlib.pyplot as plt
import patsy # for comparison

these_knots = np.linspace(0,1,5)

# in Patsy/R: nice and sensible
x = np.linspace(0., 1., 100)
y = patsy.bs(x, knots=these_knots, degree=3)
plt.subplot(1,2,1)
plt.plot(x,y)
plt.title('B-spline basis')

# in scipy: ?????
numpyknots = np.concatenate(([0,0,0],these_knots,[1,1,1])) # because??
y_py = np.zeros((x.shape[0], len(these_knots)+2))
for i in range(len(these_knots)+2):
    y_py[:,i] = intrp.BSpline(numpyknots, (np.arange(len(these_knots)+2)==i).astype(float), 3, extrapolate=False)(x)

plt.subplot(1,2,2)
plt.plot(x,y_py)
plt.title('In SciPy')



Answer 1:


It looks like you have already found the answer, but to clarify why you need to repeat the knots at the edges, see the scipy docs. The basis functions are defined by the Cox-de Boor recursion formula. The recursion starts from degree-zero functions, each equal to 1 on one knot span (the interval between two neighbouring knots) and 0 elsewhere. These are then combined recursively to build the higher-degree basis functions. Hence two adjacent spans make one first-degree basis function, three spans make one second-degree basis function, and four spans (= 5 knot points) make one third-degree basis function, which is supported only within the range of those 5 knot points. If you want n basis functions of degree k = 3, you need (n+k+1) knot points.
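As a rough illustration of the recursion described above, here is a minimal pure-Python sketch of the Cox-de Boor formula. The function name cox_de_boor and the handling of repeated knots (terms with a zero-width denominator are simply dropped) are choices made for this sketch; it is not how scipy implements the evaluation.

```python
def cox_de_boor(x, i, k, t):
    """Value at x of the i-th B-spline basis function of degree k on knots t."""
    if k == 0:
        # Degree 0: indicator of the half-open knot span [t[i], t[i+1]).
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    left = right = 0.0
    # Terms whose knot span has zero width are dropped (0/0 treated as 0).
    if t[i + k] > t[i]:
        left = (x - t[i]) / (t[i + k] - t[i]) * cox_de_boor(x, i, k - 1, t)
    if t[i + k + 1] > t[i + 1]:
        right = ((t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1])
                 * cox_de_boor(x, i + 1, k - 1, t))
    return left + right


k = 3
interior = [0.0, 0.25, 0.5, 0.75, 1.0]
t = [0.0] * k + interior + [1.0] * k   # same padding as numpyknots in the question
n = len(t) - k - 1                     # number of basis functions

x = 0.4
values = [cox_de_boor(x, i, k, t) for i in range(n)]
```

Each basis function i only touches knots t[i] through t[i + k + 1], which is where the (n+k+1) knot count comes from. Because the degree-zero spans are half-open, this sketch evaluates to 0 for every basis function at exactly x = 1; library implementations treat the right endpoint specially.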

The minimum of 8 knots follows from scipy's requirement that n >= k + 1: with n + k + 1 knots in total, that gives at least 2 * (k+1) = 8 knots for k = 3. The base interval t[k] ... t[n] is the only range on which the full set of degree-k basis functions is defined. To make this base interval reach the outermost knot points, the two end knots are usually given a multiplicity of (k+1). Your 'Something else' plot most likely shows only this base interval, which for 10 equally spaced knots on [0, 1] is [1/3, 2/3].
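The effect of the base interval can be checked directly. The sketch below compares the unpadded knot vector from the first attempt with the padded one. The helper name basis_matrix is made up for this illustration. It relies on BSpline returning NaN outside the base interval when extrapolate=False.

```python
import numpy as np
from scipy.interpolate import BSpline

k = 3
x = np.linspace(0.0, 1.0, 101)


def basis_matrix(t, k, x):
    """Columns are the len(t) - k - 1 basis functions evaluated at x.

    Entries are NaN where x lies outside the base interval [t[k], t[n]].
    """
    n = len(t) - k - 1
    return np.column_stack([
        BSpline(t, np.eye(n)[i], k, extrapolate=False)(x) for i in range(n)
    ])


# Unpadded knots, as in the first attempt: the base interval is only [1/3, 2/3].
t_plain = np.linspace(0.0, 1.0, 10)
B_plain = basis_matrix(t_plain, k, x)
n_plain = B_plain.shape[1]
base_plain = (t_plain[k], t_plain[n_plain])

# End knots repeated to multiplicity k + 1: the base interval becomes [0, 1].
t_padded = np.concatenate(([0.0] * k, np.linspace(0.0, 1.0, 5), [1.0] * k))
B_padded = basis_matrix(t_padded, k, x)
```

On the padded knots the rows of B_padded sum to 1 across the whole of [0, 1], whereas B_plain contains NaN for every x outside [1/3, 2/3].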

Note that you can also get each basis function using:

y_py[:,i] = intrp.BSpline.basis_element(numpyknots[i:i+5], extrapolate=False)(x)

This also removes the discrepancy at x = 1.
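For completeness, here is a hedged sketch of the basis_element variant in a full loop, compared against the unit-coefficient approach from the question. The names from_elements and from_coeffs are introduced only for this example. Since extrapolate=False yields NaN outside each element's own support, the NaNs are replaced by zeros with np.nan_to_num. The grid deliberately leaves out the right endpoint x = 1, whose handling this sketch does not examine.

```python
import numpy as np
from scipy.interpolate import BSpline

k = 3
these_knots = np.linspace(0.0, 1.0, 5)
numpyknots = np.concatenate(([0.0] * k, these_knots, [1.0] * k))
n_basis = len(numpyknots) - k - 1

# Half-open grid on [0, 1): the endpoint x = 1 is excluded on purpose.
x = np.linspace(0.0, 1.0, 100, endpoint=False)

# One basis function per window of k + 2 = 5 consecutive knots;
# NaN outside the window's support is mapped to 0.
from_elements = np.column_stack([
    np.nan_to_num(
        BSpline.basis_element(numpyknots[i:i + k + 2], extrapolate=False)(x)
    )
    for i in range(n_basis)
])

# The same basis via unit coefficient vectors on the full knot vector.
from_coeffs = np.column_stack([
    BSpline(numpyknots, np.eye(n_basis)[i], k, extrapolate=False)(x)
    for i in range(n_basis)
])

same = np.allclose(from_elements, from_coeffs)
```

If the result is going into a regression design matrix rather than a plot, the np.nan_to_num step matters, because any leftover NaN would propagate through the fit.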



Source: https://stackoverflow.com/questions/61807542/generate-a-b-spline-basis-in-scipy-like-bs-in-r
