How do I get SymPy to collect partial derivatives?

Submitted by 前提是你 on 2021-01-27 19:59:03

Question


I have been using SymPy to expand the terms of a complicated partial differential equation, and I would like to use the collect function to gather terms. However, collect seems to have trouble with second- (or higher-) order mixed partial derivatives, i.e. derivatives whose variables of differentiation differ.
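A SymPy Derivative records its variables of differentiation in its variables attribute (repeated once per order), so it is easy to check which terms are "mixed" in this sense. The sketch below assumes, based on the error message, that collect only copes with derivatives whose variables are all identical (such as x, x); mixed_derivatives is a hypothetical helper name, not part of SymPy.

```python
from sympy import Derivative, Function, symbols


def mixed_derivatives(expr):
    """Hypothetical helper: return the Derivative sub-expressions of expr
    whose variables of differentiation are not all the same."""
    return {d for d in expr.atoms(Derivative) if len(set(d.variables)) > 1}


x, y, z, t, U = symbols('x y z t U')
psi = Function('psi')(x, y, z, t)

expr6 = 2*psi.diff(x, x) + 3*U*psi.diff(x) + 5*psi.diff(y)
expr7 = 2*psi.diff(x, y) + 3*U*psi.diff(x) + 5*psi.diff(y)

# psi.diff(x, x).variables is (x, x): a pure second derivative.
# psi.diff(x, y).variables is (x, y): a mixed partial derivative.
print(mixed_derivatives(expr6))  # set()
print(mixed_derivatives(expr7))  # {Derivative(psi(x, y, z, t), x, y)}
```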

In the example below, collect(expr6, ...) works, but collect(expr7, ...) does not; it raises "NotImplementedError: Improve MV Derivative support in collect". The only difference between the two expressions is that expr7 contains the mixed derivative psi.diff(x, y) where expr6 has psi.diff(x, x), so the error is clearly related to that term. Is it obvious to anyone what I need to do to make collect(expr7, ...) work?

cheers

Richard

Example:

from sympy import *

x, y, z, t, U = symbols('x y z t U')
psi = Function("psi")(x, y, z, t)

expr6 = 2*psi.diff(x, x) + 3*U*psi.diff(x) + 5*psi.diff(y)
expr7 = 2*psi.diff(x, y) + 3*U*psi.diff(x) + 5*psi.diff(y)

collect(expr6, psi.diff(x), evaluate=False, exact=False)  # works
# collect(expr7, psi.diff(x), evaluate=False, exact=False)
#   raises NotImplementedError: Improve MV Derivative support in collect
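If what you ultimately need is only the coefficient of each chosen derivative, one alternative that sidesteps collect is Expr.coeff, which matches whole factors of each term rather than interpreting a derivative's variables, so a mixed partial does not trip it up. This is a different technique from collect, not a fix for it; coefficients_of is a hypothetical helper, and it assumes every target appears to the first power with at most one target per term.

```python
from sympy import Function, S, symbols


def coefficients_of(expr, targets):
    """Hypothetical helper: map each target to its coefficient in expr and
    put everything else under the key S.One, similar in shape to
    collect(..., evaluate=False).

    Assumes every target appears linearly and at most one target per term.
    """
    expr = expr.expand()
    result = {}
    rest = expr
    for target in targets:
        c = expr.coeff(target)  # coefficient of target**1
        if c != 0:
            result[target] = c
            rest -= c*target    # like terms cancel automatically in Add
    rest = rest.expand()
    if rest != 0:
        result[S.One] = rest
    return result


x, y, z, t, U = symbols('x y z t U')
psi = Function('psi')(x, y, z, t)
expr7 = 2*psi.diff(x, y) + 3*U*psi.diff(x) + 5*psi.diff(y)

print(coefficients_of(expr7, [psi.diff(x), psi.diff(x, y)]))
```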

Answer 1:


I've bumped into this issue too. My workaround is to first substitute simple dummy symbols for the derivatives, collect with respect to those dummies, and then substitute the original derivatives back. There may be corner cases, but it works for me.

from sympy import symarray, collect  # note: symarray requires NumPy

def mycollect(expr, var_list, evaluate=True, **kwargs):
    """ Acts like collect, but substitutes dummy symbols for the variables first
        so that it can work with partial derivatives.
        Matrix expressions are also supported (they are always returned evaluated).
    """
    if not hasattr(var_list, '__len__'):
        var_list = [var_list]
    # Mapping Var -> Dummy, and Dummy -> Var
    Dummies = symarray('DUM', len(var_list))
    Var2Dummy = [(var, Dummies[i]) for i, var in enumerate(var_list)]
    Dummy2Var = [(b, a) for a, b in Var2Dummy]
    # Replace var with dummies and apply collect
    expr = expr.expand().doit()
    expr = expr.subs(Var2Dummy)
    if hasattr(expr, '__len__'):
        # Matrix: collect each entry; a matrix of dicts is not meaningful,
        # so evaluate=False is ignored here
        expr = expr.applyfunc(lambda ij: collect(ij, Dummies, **kwargs))
        return expr.subs(Dummy2Var)
    expr = collect(expr, Dummies, evaluate=evaluate, **kwargs)
    # Substitute back
    if evaluate:
        return expr.subs(Dummy2Var)
    d = {}
    for k, v in expr.items():
        k = k.subs(Dummy2Var)
        v = v.subs(Dummy2Var)
        d[k] = v
    return d

For your example:

mycollect(expr6, psi.diff(x), evaluate=False)
mycollect(expr7, psi.diff(x), evaluate=False)

returns:

{Derivative(psi(x, y, z, t), (x, 2)): 2, Derivative(psi(x, y, z, t), x): 3*U, 1: 5*Derivative(psi(x, y, z, t), y)}
{Derivative(psi(x, y, z, t), x, y): 2, Derivative(psi(x, y, z, t), x): 3*U, 1: 5*Derivative(psi(x, y, z, t), y)}
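The same dummy-substitution idea can also be written more compactly. The sketch below is a hypothetical variant, not the answer's code: collect_derivatives replaces every Derivative in the expression (not only the ones being collected) with a fresh Dummy via xreplace, so collect only ever sees plain symbols, and it avoids the NumPy dependency of symarray.

```python
from sympy import Derivative, Dummy, Function, S, collect, symbols


def collect_derivatives(expr, targets, evaluate=True, **kwargs):
    """Hypothetical variant of the dummy-substitution workaround.

    Every Derivative in expr (plus every target) is swapped for a plain
    Dummy symbol, collect runs on the dummies, and the originals are
    swapped back afterwards.
    """
    expr = expr.expand()
    derivs = expr.atoms(Derivative) | set(targets)
    to_dummy = {d: Dummy('d') for d in derivs}
    from_dummy = {v: k for k, v in to_dummy.items()}

    collected = collect(expr.xreplace(to_dummy),
                        [to_dummy[t] for t in targets],
                        evaluate=evaluate, **kwargs)
    if evaluate:
        return collected.xreplace(from_dummy)
    # evaluate=False: collect returned a dict; restore keys and values
    return {k.xreplace(from_dummy): v.xreplace(from_dummy)
            for k, v in collected.items()}


x, y, z, t, U = symbols('x y z t U')
psi = Function('psi')(x, y, z, t)
expr7 = 2*psi.diff(x, y) + 3*U*psi.diff(x) + 5*psi.diff(y)

print(collect_derivatives(expr7, [psi.diff(x)], evaluate=False))
```

In this version a mixed partial such as psi.diff(x, y) only gets its own key if it is included in targets; otherwise it stays in the S.One remainder. xreplace is used instead of subs because it replaces exact sub-expressions only, whereas subs can apply extra rules when the expression being replaced is itself a Derivative.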


Source: https://stackoverflow.com/questions/58700443/how-do-i-get-sympy-to-collect-partial-derivatives
