Question
Given N one-dimensional data points X, I would like to evaluate each point on K cubic B-spline basis functions. In R, there is a simple function with an intuitive API for this, called bs. There is also a Python package, patsy, which replicates it, but I can't use that package, only scipy and the like.
Having looked through the scipy.interpolate documentation on spline-related functions, the closest I can find is BSpline, or BSpline.basis_element, but how to get just the K basis functions is totally mysterious to me. I tried the following:
import numpy as np
import scipy.interpolate as intrp
import matplotlib.pyplot as plt
import patsy # for comparison
# in Patsy/R: nice and sensible
x = np.linspace(0., 1., 100)
y = patsy.bs(x, knots=np.linspace(0,1,4), degree=3)
plt.subplot(1,2,1)
plt.plot(x,y)
plt.title('B-spline basis')
# in scipy: ?????
y_py = np.zeros((x.shape[0], 6))
for i in range(6):
    y_py[:,i] = intrp.BSpline(np.linspace(0,1,10),(np.arange(6)==i).astype(float), 3, extrapolate=False)(x)
plt.subplot(1,2,2)
plt.plot(x,y_py)
plt.title('Something else')
It doesn't work, and it makes me realise I don't actually know what this function is doing. First, it will not accept fewer than 8 interior knots, and I don't understand why. Second, it seems to think the splines are only defined within roughly the (1/3, 2/3) range, which may mean that it is ignoring the first 3 and last 3 knot values for some reason. Do I need to pad the knots?
Any help would be appreciated!
EDIT: I have solved this discrepancy; it does indeed seem that BSpline ignores the first 3 and last 3 knot values. I'm still interested in knowing why there is this discrepancy, so that I feel less bad about the odd hour spent debugging a strange interface.
For posterity, here is the code that does produce the basis functions:
import numpy as np
import scipy.interpolate as intrp
import matplotlib.pyplot as plt
import patsy # for comparison
these_knots = np.linspace(0,1,5)
# in Patsy/R: nice and sensible
x = np.linspace(0., 1., 100)
y = patsy.bs(x, knots=these_knots, degree=3)
plt.subplot(1,2,1)
plt.plot(x,y)
plt.title('B-spline basis')
# in scipy: ?????
numpyknots = np.concatenate(([0,0,0],these_knots,[1,1,1])) # because??
y_py = np.zeros((x.shape[0], len(these_knots)+2))
for i in range(len(these_knots)+2):
    y_py[:,i] = intrp.BSpline(numpyknots, (np.arange(len(these_knots)+2)==i).astype(float), 3, extrapolate=False)(x)
plt.subplot(1,2,2)
plt.plot(x,y_py)
plt.title('In SciPy')
Answer 1:
It looks like you already found the answer, but to clarify why you need to define multiple knots at the edges, you can read the scipy docs. The basis functions are defined using the Cox-de Boor recursive formula. The recursion starts from the intervals between neighbouring knot points, each carrying a piecewise-constant function equal to 1 on its interval (degree zero). These are combined recursively to obtain the higher-degree basis functions. Hence two intervals make one first-degree basis function, three intervals make one second-degree basis function, and four intervals (= 5 knot points) make one third-degree basis function that is supported within the range of those 5 knot points. If you want n basis functions of degree k = 3, you will need (n+k+1) knot points.
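To make the recursion concrete, here is a minimal NumPy sketch of the Cox-de Boor formula. The helper name cox_de_boor is made up for illustration and is not part of scipy; the knot vector is the padded one from the question's working code, and the right endpoint x = 1 is left out because the sketch uses half-open intervals.

```python
import numpy as np

def cox_de_boor(x, t, i, k):
    """Value at x of the i-th B-spline basis function of degree k on knots t.

    Illustrative helper (not a scipy function). Uses half-open intervals
    [t[i], t[i+1]), so the right end of the knot range is not covered.
    """
    if k == 0:
        # Degree zero: indicator function of one knot interval.
        return np.where((t[i] <= x) & (x < t[i + 1]), 1.0, 0.0)
    left = np.zeros_like(x)
    right = np.zeros_like(x)
    # Terms with a zero-width denominator (repeated knots) are dropped.
    if t[i + k] > t[i]:
        left = (x - t[i]) / (t[i + k] - t[i]) * cox_de_boor(x, t, i, k - 1)
    if t[i + k + 1] > t[i + 1]:
        right = ((t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1])
                 * cox_de_boor(x, t, i + 1, k - 1))
    return left + right

k = 3
these_knots = np.linspace(0, 1, 5)
t = np.concatenate(([0.0] * k, these_knots, [1.0] * k))  # 11 knots in total
n = len(t) - k - 1                                       # number of basis functions
x = np.linspace(0, 1, 100, endpoint=False)
basis = np.column_stack([cox_de_boor(x, t, i, k) for i in range(n)])
```

Each cubic basis function touches 5 consecutive knots, so 11 knots give 11 - 3 - 1 = 7 functions, and on the padded knot vector they sum to 1 at every x in [0, 1).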
The minimum of 8 knots comes from the requirement n >= k + 1: the knot count n + k + 1 is then at least 2 * (k+1), which is 8 for k = 3. The base interval t[k] ... t[n] in scipy is the only range where a full set of degree-k basis functions is defined. To make sure that this base interval reaches the outer knot points, the two end knots are usually given a multiplicity of (k+1). scipy probably showed only this base interval in your 'Something else' result.
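As a small check of the base-interval explanation, the sketch below rebuilds the unpadded knot vector from the question's first attempt and looks at where BSpline returns values. The variable names are illustrative; the only scipy call used is the BSpline constructor with extrapolate=False, which returns nan outside the base interval.

```python
import numpy as np
from scipy.interpolate import BSpline

k = 3
t = np.linspace(0, 1, 10)    # 10 unpadded knots, as in the first attempt
n = len(t) - k - 1           # 6 basis functions

# Unit coefficient vector: selects the basis function with index 2.
c = np.zeros(n)
c[2] = 1.0
spl = BSpline(t, c, k, extrapolate=False)

lo, hi = t[k], t[n]          # base interval: t[3] = 1/3 and t[6] = 2/3
inside = spl(0.5)            # finite value, inside the base interval
outside = spl(np.array([0.1, 0.9]))  # nan, outside the base interval
```

With 10 equally spaced knots on [0, 1], the base interval is [1/3, 2/3], which matches the range seen in the 'Something else' plot.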
Note that you can also get the basis functions using
y_py[:,i] = intrp.BSpline.basis_element(numpyknots[i:i+5], extrapolate=False)(x)
This also removes the difference at x = 1.
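A self-contained sketch comparing the two constructions on the padded knot vector from the question, without patsy or plotting. The names via_coeffs and via_element are illustrative. Note that BSpline.basis_element with extrapolate=False returns nan outside the support of each element, so the sketch replaces nan with 0 before comparing; the evaluation grid is chosen, as an assumption of this sketch, to stay strictly inside (0, 1).

```python
import numpy as np
from scipy.interpolate import BSpline

k = 3
these_knots = np.linspace(0, 1, 5)
t = np.concatenate(([0, 0, 0], these_knots, [1, 1, 1]))  # padded knots
n = len(t) - k - 1                                       # 7 basis functions
x = np.linspace(0.01, 0.99, 48)                          # strictly inside (0, 1)

# Construction 1: full knot vector with a unit coefficient vector per function.
via_coeffs = np.column_stack(
    [BSpline(t, np.eye(n)[i], k, extrapolate=False)(x) for i in range(n)]
)

# Construction 2: one basis element per window of k + 2 = 5 consecutive knots.
via_element = np.column_stack(
    [BSpline.basis_element(t[i:i + k + 2], extrapolate=False)(x) for i in range(n)]
)
# Outside its own support each element evaluates to nan; treat that as 0.
via_element = np.nan_to_num(via_element, nan=0.0)
```

On this grid the two matrices agree, and each row sums to 1.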
Source: https://stackoverflow.com/questions/61807542/generate-a-b-spline-basis-in-scipy-like-bs-in-r