How to extract the function model (polynomials) from scipy.interpolate.splprep()?

不问归期 submitted on 2021-02-10 22:10:57

Question


I have some discrete points, and I interpolated them using the scipy.interpolate.splprep() function (B-spline interpolation) to get a satisfactorily smooth curve. Here's the code (based on the answer to another question) and the result I got.

import numpy as np
from scipy import interpolate
from matplotlib import pyplot as plt

# x and y are points sampled randomly (sampledx and sampledy stand for your own data arrays)
x = sampledx
y = sampledy

# append the starting x,y coordinates
x = np.r_[x, x[0]]
y = np.r_[y, y[0]]

# fit splines to x=f(u) and y=g(u), treating both as periodic. also note that s=0
# is needed in order to force the spline fit to pass through all the input points.
tck, u = interpolate.splprep([x, y], s=0, per=True)

# evaluate the spline fits for 1000 evenly spaced distance values
xi, yi = interpolate.splev(np.linspace(0, 1, 1000), tck)

# plot the result
fig, ax = plt.subplots(figsize=(12, 12))
ax.plot(x, y, 'or')
ax.plot(xi, yi, '-b')
plt.show()

[Figure: obtained curve]

As far as I know, the function model obtained by cubic spline interpolation is a series of polynomials. Now I want to extract this function model, so I tried printing the contents of tck:

[array([-0.30733587, -0.28200105, -0.22446703,  0.        ,  0.03802363,
         0.07911629,  0.09557235,  0.15790186,  0.20199024,  0.24140097,
         0.26977782,  0.31416052,  0.35118666,  0.42856196,  0.45166591,
         0.49503978,  0.51375395,  0.56799754,  0.59262884,  0.61845984,
         0.65603571,  0.69266413,  0.71799895,  0.77553297,  1.        ,
         1.03802363,  1.07911629,  1.09557235]),
 [array([229.12471144, -98.86968613,  50.15238681,  83.22909902,
          88.9466649 , 103.43169139, 158.24339347, 200.28605252,
         245.21725764, 291.11861604, 356.23057282, 404.75955996,
         429.18100345, 435.79417275, 430.58694659, 402.28422935,
         381.19094487, 360.28746542, 316.79933633, 271.50003508,
         242.72352701, 229.12471144, -98.86968613,  50.15238681]),
  array([-77.44508113, 184.01906954, 197.43235399, 226.25242057,
         275.95919475, 329.12264277, 360.20146464, 378.28519513,
         391.18454729, 390.47825093, 380.06668473, 339.92688063,
         285.65908782, 250.27639394, 201.82803336, 168.81117187,
         133.96870427,  94.65595445, 126.9811583 , 121.02433492,
          78.83626675, -77.44508113, 184.01906954, 197.43235399])],
 3]

After consulting the relevant documentation, I learned that the first array is the list of knots, the second and third arrays are lists of coefficients, and the last number is the degree. If I understand correctly, the function model is composed of 7 polynomials of degree at most 3. How can I extract the function model (the polynomials) from these parameters? Thanks a lot.


Answer 1:


The tck returned by interpolate.splprep consists of 3 parts:

  • tck[0]: the knots of the b-splines (these are values of the parameter u)
  • tck[1]: the x and y coordinates of the calculated control points
  • tck[2]: the degree of the b-splines (3 for these cubic b-splines)
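To make the three parts concrete, here is a minimal sketch (assuming SciPy, and reusing the 4-point quadrangle from the example further down as data) that unpacks tck and checks how knots, coefficients and degree relate. For the splines returned by splprep, len(t) == len(c) + k + 1, and every non-empty knot interval inside the base interval [t[k], t[len(c)]] carries one cubic polynomial per coordinate. Applying the same count to the tck printed in the question (28 knots, 24 coefficients per coordinate, knots 0 through 1 at indices 3 to 24) gives 21 cubic pieces for each coordinate.

```python
import numpy as np
from scipy import interpolate

# Example data: the 4-point quadrangle, closed by repeating the first point
x = np.array([0, 1, 40, 45, 0], dtype=float)
y = np.array([0, 22, 35, 7, 0], dtype=float)
tck, u = interpolate.splprep([x, y], s=0, per=True)

t, (cx, cy), k = tck      # knots, [x coefficients, y coefficients], degree
n_coef = len(cx)          # one coefficient (control point) per basis function

# For the splines returned by splprep: number of knots = coefficients + degree + 1
assert len(t) == n_coef + k + 1

# The curve is traced for u in the base interval [t[k], t[n_coef]];
# each non-empty knot interval in there holds one cubic per coordinate.
base_knots = t[k:n_coef + 1]
n_pieces = int(np.count_nonzero(np.diff(base_knots) > 0))

print(f"{len(t)} knots, {n_coef} coefficients per coordinate, degree {k}")
print(f"{n_pieces} polynomial pieces each for x(u) and y(u) on [{t[k]:g}, {t[n_coef]:g}]")
```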

interpolate.splprep also returns an array of u values (u_ticks in the code below). These are the values of u at which the b-spline passes through each of the points to be interpolated. In the plot below, they are marked with black lines on the colorbar.
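A quick way to see what u_ticks contains: according to the splprep documentation, when u is not passed in it is the cumulative chord length between consecutive points, normalized so the last value is 1. With s=0, evaluating the spline at those values reproduces the input points. A minimal check, assuming the 4-point quadrangle data used in this answer:

```python
import numpy as np
from scipy import interpolate

# Example data: the 4-point quadrangle, closed by repeating the first point
x = np.array([0, 1, 40, 45, 0], dtype=float)
y = np.array([0, 22, 35, 7, 0], dtype=float)
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)

# Recompute the default parameter values: normalized cumulative chord length
chord = np.hypot(np.diff(x), np.diff(y))
expected_u = np.concatenate([[0.0], np.cumsum(chord)]) / chord.sum()

# Evaluating the spline at u_ticks gives back the input points (s=0 interpolates)
x_at, y_at = interpolate.splev(u_ticks, tck)
print(np.allclose(u_ticks, expected_u))
print(np.allclose(x_at, x), np.allclose(y_at, y))
```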

A set of b-spline basis functions can be calculated from the knots. There is one basis function for each control point (24 in your example).
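How the basis functions follow from the knots alone can be sketched with the Cox-de Boor recursion, which builds the degree-p basis functions from the degree-(p-1) ones. The helper name cox_de_boor is hypothetical (it is not a SciPy function). It uses half-open knot intervals, which works for the periodic knot vectors splprep produces here but would need special handling at the last knot of a clamped knot vector. As a sanity check, each row is compared with scipy.interpolate.BSpline given a one-hot coefficient vector, and the rows should sum to 1 over the base interval.

```python
import numpy as np
from scipy import interpolate


def cox_de_boor(t, k, u):
    """Evaluate all degree-k B-spline basis functions on knot vector t at points u.

    Hypothetical helper (not part of SciPy). Returns an array of shape
    (len(t) - k - 1, len(u)); row i holds B_{i,k}(u).
    """
    t = np.asarray(t, dtype=float)
    u = np.asarray(u, dtype=float)
    # Degree 0: indicator of the half-open knot interval [t_i, t_{i+1})
    B = np.array([(t[i] <= u) & (u < t[i + 1]) for i in range(len(t) - 1)], dtype=float)
    for p in range(1, k + 1):
        B_next = np.zeros((len(t) - p - 1, u.size))
        for i in range(len(t) - p - 1):
            left_den = t[i + p] - t[i]
            right_den = t[i + p + 1] - t[i + 1]
            # Terms with a zero denominator (repeated knots) are taken as 0
            if left_den > 0:
                B_next[i] += (u - t[i]) / left_den * B[i]
            if right_den > 0:
                B_next[i] += (t[i + p + 1] - u) / right_den * B[i + 1]
        B = B_next
    return B


# Knots and degree from the 4-point periodic example in this answer
x = np.array([0, 1, 40, 45, 0], dtype=float)
y = np.array([0, 22, 35, 7, 0], dtype=float)
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)
t, _, k = tck
n_basis = len(t) - k - 1

us = np.linspace(t[k], t[n_basis], 201)   # base interval, here [0, 1]
B = cox_de_boor(t, k, us)

# Cross-check each row against SciPy's BSpline with a one-hot coefficient vector
for i in range(n_basis):
    onehot = np.zeros(n_basis)
    onehot[i] = 1.0
    assert np.allclose(B[i], interpolate.BSpline(t, onehot, k)(us))

print(B.shape)                           # (number of basis functions, number of u values)
print(np.allclose(B.sum(axis=0), 1.0))   # partition of unity on the base interval
```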

To draw the curve, u varies from 0 to 1; this is the np.linspace(0, 1, 1000) in your example code. For each u value, the x coordinate of each control point is multiplied by its basis function evaluated at u, and the products are summed. The same happens for y.
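The weighted sum described above can be checked numerically without sympy. This sketch (assuming the 4-point quadrangle data from the example below) builds every basis function by handing scipy.interpolate.BSpline a one-hot coefficient vector, which is one convenient way to isolate a single basis function, then multiplies by the x and y coefficients from tck[1], sums, and compares the result with splev.

```python
import numpy as np
from scipy import interpolate

# Example data: the 4-point quadrangle, closed by repeating the first point
x = np.array([0, 1, 40, 45, 0], dtype=float)
y = np.array([0, 22, 35, 7, 0], dtype=float)
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)
t, (cx, cy), k = tck
n_basis = len(t) - k - 1          # equals len(cx) == len(cy)

us = np.linspace(0, 1, 1000)

# Row i holds basis function i evaluated at every u; a BSpline whose
# coefficient vector is one-hot at i is exactly that basis function.
B = np.empty((n_basis, us.size))
for i in range(n_basis):
    onehot = np.zeros(n_basis)
    onehot[i] = 1.0
    B[i] = interpolate.BSpline(t, onehot, k)(us)

# x(u) = sum_i cx[i] * B_i(u), and likewise for y
xs = cx @ B
ys = cy @ B

xi, yi = interpolate.splev(us, tck)
print(np.allclose(xs, xi), np.allclose(ys, yi))
```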

Sympy's bspline_basis_set can be used to show what these functions look like.

Here is an example with just 4 points; as you'll notice, the functions quickly become quite complex.

import numpy as np
from scipy import interpolate
from matplotlib import pyplot as plt

# x and y for a simple quadrangle
x = [0, 1, 40, 45]
y = [0, 22, 35, 7]

# append the starting x,y coordinates
x = np.r_[x, x[0]]
y = np.r_[y, y[0]]

# fit splines to x=f(u) and y=g(u), treating both as periodic. also note that s=0
# is needed in order to force the spline fit to pass through all the input points.
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)

# evaluate the spline fits for 1000 evenly spaced distance values
xi, yi = interpolate.splev(np.linspace(0, 1, 1000), tck)

# plot the result
fig, ax = plt.subplots(figsize=(12, 12))
ax.plot(x, y, 'Pk', ms=10, label='Points to interpolate')
ax.plot(xi, yi, '-b', lw=1, label='Interpolating spline (splev)', zorder=0)
ax.plot(tck[1][0], tck[1][1], 'om', ls=':', label='Calculated control points')

from sympy import lambdify, bspline_basis_set
from sympy.abc import u

basis = bspline_basis_set(tck[2], tck[0],  u)
for i, b in enumerate(basis):
    print(f"Basis {i} :", b)

# convert the basis functions to numpy so they can be evaluated quicker
np_basis = [lambdify(u, b, modules=['numpy']) for b in basis]

tck_x = tck[1][0]
tck_y = tck[1][1]

us = np.linspace(0, 1, 100)
xs = [sum([xi * bi(u_val) for xi, bi in zip(tck_x, np_basis)]) for u_val in us]
ys = [sum([yi * bi(u_val) for yi, bi in zip(tck_y, np_basis)]) for u_val in us]
plt.scatter(xs, ys, c=us, s=40, marker='o', cmap='tab10')
plt.legend()
cbar = plt.colorbar(label='u values')
for t in u_ticks:
    # mark the position of the u_ticks at the color bar
    cbar.ax.axhline(t, lw=3, color='black', clip_on=False)

plt.show()

Output:

Basis 0 : Piecewise((7.83358627878421*u**3 + 19.7262258572059*u**2 + 16.5579328428993*u + 4.63283654316489, (u >= -0.83938676170286) & (u <= -0.539571441177499)), (-34.7262442279844*u**3 - 49.1659813912158*u**2 - 20.6143347080305*u - 2.05286144826537, (u >= -0.539571441177499) & (u <= -0.332135154281002)), (23.3437491730212*u**3 + 8.69527726080352*u**2 - 1.39657663874914*u + 0.0747695654932114, (u >= -0.332135154281002) & (u <= 0)), (-18.0459953633398*u**3 + 8.69527726080352*u**2 - 1.39657663874914*u + 0.0747695654932114, (u >= 0) & (u <= 0.16061323829714)), (0, True))
Basis 1 : Piecewise((12.7600892248919*u**3 + 20.6549391978852*u**2 + 11.1448153104365*u + 2.00447468623643, (u >= -0.539571441177499) & (u <= -0.332135154281002)), (-24.4055001260175*u**3 - 16.3770570611408*u**2 - 1.15481248038858*u + 0.642761761601563, (u >= -0.332135154281002) & (u <= 0)), (51.0502963670014*u**3 - 16.3770570611408*u**2 - 1.15481248038858*u + 0.642761761601563, (u >= 0) & (u <= 0.16061323829714)), (-9.14007459775806*u**3 + 12.6250541237277*u**2 - 5.81293547524402*u + 0.892147167798265, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (0, True))
Basis 2 : Piecewise((7.70949185527263*u**3 + 7.68177980033731*u**2 + 2.55138911913772*u + 0.282468672905225, (u >= -0.332135154281002) & (u <= 0)), (-53.251633917268*u**3 + 7.68177980033731*u**2 + 2.55138911913772*u + 0.282468672905225, (u >= 0) & (u <= 0.16061323829714)), (29.8321355272912*u**3 - 32.3512799809336*u**2 + 8.98122848955063*u - 0.0617704347655956, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (-14.2299460617349*u**3 + 28.5110421933306*u**2 - 19.0415227957366*u + 4.2390545614098, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (0, True))
Basis 3 : Piecewise((20.2473329136064*u**3, (u >= 0) & (u <= 0.16061323829714)), (-28.5256472083174*u**3 + 23.5007588363526*u**2 - 3.77453297914672*u + 0.202079988280036, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (36.1961010648274*u**3 - 65.8984650092776*u**2 + 37.387422815947*u - 6.1153000067368, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (-6.64774090227629*u**3 + 19.9432227068289*u**2 - 19.9432227068289*u + 6.64774090227629, (u >= 0.667864845718998) & (u <= 1.0)), (0, True))
Basis 4 : Piecewise((7.83358627878421*u**3 - 3.77453297914672*u**2 + 0.606239964840107*u - 0.0324567213127046, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (-34.7262442279844*u**3 + 55.0127512927375*u**2 - 26.4611046095522*u + 4.1217360965338, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (23.3437491730212*u**3 - 61.3359702582601*u**2 + 51.2441163587074*u - 13.1771257079753, (u >= 0.667864845718998) & (u <= 1.0)), (-18.0459953633398*u**3 + 62.8332633508229*u**2 - 72.9251172503755*u + 28.2126188283857, (u >= 1.0) & (u <= 1.16061323829714)), (0, True))
Basis 5 : Piecewise((12.7600892248919*u**3 - 17.6253284767905*u**2 + 8.11520458934184*u - 1.2454906512068, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (-24.4055001260175*u**3 + 56.8394433169118*u**2 - 41.6171987361595*u + 9.82601730686685, (u >= 0.667864845718998) & (u <= 1.0)), (51.0502963670015*u**3 - 169.527946162145*u**2 + 184.750190742898*u - 65.6297791861522, (u >= 1.0) & (u <= 1.16061323829714)), (-9.14007459775806*u**3 + 40.0452779170019*u**2 - 58.4832675159736*u + 28.470211364528, (u >= 1.16061323829714) & (u <= 1.4604285588225)), (0, True))
Basis 6 : Piecewise((7.70949185527263*u**3 - 15.4466957654806*u**2 + 10.316305084281*u - 2.29663250116781, (u >= 0.667864845718998) & (u <= 1.0)), (-53.2516339172681*u**3 + 167.436681552142*u**2 - 172.567072233341*u + 58.6644932713729, (u >= 1.0) & (u <= 1.16061323829714)), (29.8321355272912*u**3 - 121.847686562807*u**2 + 163.180195033291*u - 71.226414432541, (u >= 1.16061323829714) & (u <= 1.4604285588225)), (-14.2299460617349*u**3 + 71.2008803785352*u**2 - 118.753445367602*u + 66.0215656122119, (u >= 1.4604285588225) & (u <= 1.667864845719)), (0, True))

Alternatively, as mentioned in this post, sympy has a not-yet-documented function interpolating_spline that calculates the piecewise functions already combined with the x values. (Note that there, 'x' is used where we use 'u', and 'y' where we use 'x', which can be confusing.)

To make this work for a closed (periodic) curve, two extra nodes need to be added at the front and two at the end. Together with the repeated node added earlier, there are now 9 nodes representing the 4 original points.

from sympy import interpolating_spline, lambdify
from sympy.abc import u

# ... the same code as above, but replacing the complete sympy part

# use the u_ticks from splprep, extended periodically with two nodes on each side
us = [u_ticks[-3] - 1, u_ticks[-2] - 1, *u_ticks, u_ticks[1] + 1, u_ticks[2] + 1]
xs = [*x[-3:-1], *x, *x[1:3]]
ys = [*y[-3:-1], *y, *y[1:3]]

interpx = interpolating_spline(tck[2], u, us, xs)
interpy = interpolating_spline(tck[2], u, us, ys)

print(interpx)
print(interpy)

fx = lambdify(u, interpx, modules=['numpy'])
fy = lambdify(u, interpy, modules=['numpy'])

us = np.linspace(0, 1, 100)
plt.scatter(fx(us), fy(us), c=us, s=40, marker='o', cmap='tab10')  # label="sympy's interpolating_spline"

Since the x values are now already combined with the basis functions, there is just one piecewise formula for x and one for y:

# for x:
Piecewise((259.449085976667*u**3 + 332.098590899285*u**2 - 53.8062007647187*u - 8.88178419700125e-16, (u >= -0.332135154281002) & (u <= 0.16061323829714)), (-889.09792969929*u**3 + 885.514157471979*u**2 - 142.692067036006*u + 4.75874894022597, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (-281.671950803575*u**3 + 46.4853533090758*u**2 + 243.620756075287*u - 54.5310698597021, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (976.463184688985*u**3 - 2474.30733116909*u**2 + 1927.16957338388*u - 429.32542690377, (u >= 0.667864845718998) & (u <= 1.16061323829714)))
# for y:
Piecewise((-737.592577045201*u**3 + 194.240200950605*u**2 + 124.804852561614*u + 3.5527136788005e-15, (u >= -0.332135154281002) & (u <= 0.16061323829714)), (-427.62807998269*u**3 + 44.8869960595423*u**2 + 148.792954449223*u - 1.28426890825692, (u >= 0.16061323829714) & (u <= 0.460428558822501)), (1396.06082019756*u**3 - 2474.14836009222*u**2 + 1308.6287731051*u - 179.291447059738, (u >= 0.460428558822501) & (u <= 0.667864845718998)), (-2.71308577093816*u**3 + 328.427396624023*u**2 - 563.113052269992*u + 237.398741416907, (u >= 0.667864845718998) & (u <= 1.16061323829714)))
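If you only need the polynomials and would rather stay in SciPy, scipy.interpolate.PPoly.from_spline can convert each coordinate's B-spline into an explicit piecewise polynomial. A minimal sketch, again assuming the 4-point quadrangle data: PPoly stores, for every knot interval [a, b], the coefficients of (u - a)**3 ... (u - a)**0. The hypothetical helper to_power_form expands them into plain coefficients of u**3 ... u**0 via np.polyval with an np.poly1d argument, which composes the two polynomials. The pieces break at the splprep knots, so they are not expected to match the sympy output above term by term, since its intervals break at different u values.

```python
import numpy as np
from scipy import interpolate

# Example data: the 4-point quadrangle, closed by repeating the first point
x = np.array([0, 1, 40, 45, 0], dtype=float)
y = np.array([0, 22, 35, 7, 0], dtype=float)
tck, u_ticks = interpolate.splprep([x, y], s=0, per=True)
t, (cx, cy), k = tck
n_coef = len(cx)

# Convert each coordinate's B-spline into an explicit piecewise polynomial
px = interpolate.PPoly.from_spline(interpolate.BSpline(t, cx, k))
py = interpolate.PPoly.from_spline(interpolate.BSpline(t, cy, k))


def to_power_form(local_coeffs, a):
    """Turn coefficients of (u - a)**3 ... (u - a)**0 into coefficients of u**3 ... u**0."""
    return np.polyval(local_coeffs, np.poly1d([1.0, -a])).coeffs


def as_text(coeffs):
    """Format power-form coefficients (highest power first) as a readable expression."""
    deg = len(coeffs) - 1
    return " + ".join(f"({c:.6g})*u**{deg - j}" for j, c in enumerate(coeffs))


pieces = []
for i in range(len(px.x) - 1):
    a, b = px.x[i], px.x[i + 1]
    # Skip zero-length intervals and pieces outside the base interval [t[k], t[n_coef]]
    if b <= a or a < t[k] or b > t[n_coef]:
        continue
    gx = to_power_form(px.c[:, i], a)
    gy = to_power_form(py.c[:, i], a)
    pieces.append((a, b, gx, gy))
    print(f"{a:.6f} <= u <= {b:.6f}")
    print("  x(u) =", as_text(gx))
    print("  y(u) =", as_text(gy))
```

Each printed piece is valid only on its own interval; px(u) and py(u) can also be called directly to evaluate the whole curve.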



Source: https://stackoverflow.com/questions/60105444/how-to-extract-the-function-model-polynomials-from-scipy-interpolate-splprep
