Printing the output rounded to 3 decimals in SymPy

北慕城南 submitted on 2019-12-11 02:19:33

Question


I have a SymPy matrix M:

In [1]: from sympy import *
In [2]: M = 1/10**6 * Matrix(([1, 10, 100], [1000, 10000, 100000]))
In [3]: M
Out[3]: 
Matrix([
[1.0e-6, 1.0e-5, 0.0001],
[ 0.001,   0.01,    0.1]])

I want to print the output rounded to 3 decimals, as follows:

In [3]: M
Out[3]: 
Matrix([
[ 0.000, 0.000, 0.000],
[ 0.001, 0.010, 0.100]])

In plain Python, I would do this:

In [5]: '{:.3f}'.format(1/10**6)
Out[5]: '0.000'

But how can I do this for a SymPy matrix?

Moreover, a more general case would be an expression that also contains Symbols:

x = symbols('x')
M = 1/10**6 * Matrix(([1, 10*x, 100], [1000, 10000*x, 100000]))

The preferred output is

In [3]: M
Out[3]: 
Matrix([
[ 0.000, 0.000*x, 0.000],
[ 0.001, 0.010*x, 0.100]])

Answer 1:


To round every number in an expression, use the following function

def round_expr(expr, num_digits):
    # Replace every numeric atom in expr with its value rounded to num_digits decimals.
    return expr.xreplace({n: round(n, num_digits) for n in expr.atoms(Number)})

It can be applied to any SymPy expression, including matrices.

x = symbols('x')
M = 1/10**6 * Matrix(([1, 10*x, 100], [1000, 10000*x, 100000]))
round_expr(M, 3)

yields

Matrix([
[  0.0,      0, 0.0],
[0.001, 0.01*x, 0.1]])
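
As a further illustration, here is a minimal self-contained sketch that applies the same function to a scalar expression instead of a matrix. The expression and the extra symbol y are made up for this example, and the exact printed form may differ between SymPy versions.

```python
from sympy import Number, symbols

def round_expr(expr, num_digits):
    # Replace every numeric atom in expr with its rounded value.
    return expr.xreplace({n: round(n, num_digits) for n in expr.atoms(Number)})

x, y = symbols('x y')  # y is a hypothetical extra symbol for this example
expr = 0.123456*x + 2.71828*y
rounded = round_expr(expr, 3)
print(rounded)
```

The symbols are left alone; only the numeric coefficients are replaced.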

For numeric matrices, the following is simpler and preserves trailing zeros:

>>> M.applyfunc(lambda x: Symbol('{:.3f}'.format(x)))
Matrix([
[0.000, 0.000, 0.000],
[0.001, 0.010, 0.100]])

Here, applyfunc applies the given function to each entry of M. A natural thing to try would be lambda x: '{:.3f}'.format(x), but SymPy matrices aren't really meant to hold strings: the strings get parsed back into numbers and the trailing zeros are dropped, resulting in

Matrix([
[  0.0,  0.0, 0.0],
[0.001, 0.01, 0.1]])

So I wrap each string in Symbol, making a new symbol out of it. Symbols can appear in matrices, and they are printed as their names; in this case the names are "0.000" and so on.
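
The Symbol trick can be packaged as a small helper. This is only a sketch: the name format_entries and its num_digits default are hypothetical, and the explicit float() conversion is an assumption added here so that formatting does not depend on how SymPy numbers handle format specifications. It therefore works only for purely numeric matrices.

```python
from sympy import Matrix, Symbol

def format_entries(matrix, num_digits=3):
    # Hypothetical helper: turn each numeric entry into a Symbol whose
    # name is the entry formatted with a fixed number of decimals.
    fmt = '{:.' + str(num_digits) + 'f}'
    return matrix.applyfunc(lambda entry: Symbol(fmt.format(float(entry))))

M = 1/10**6 * Matrix([[1, 10, 100], [1000, 10000, 100000]])
formatted = format_entries(M)
print(formatted)
```

The entries of the result are symbols rather than numbers, so the formatted matrix is suitable for display only, not for further arithmetic.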




Answer 2:


This question is a little old, but I spent some time on it today and needed a small adjustment. My solution to the exact same question follows the previous answer (by user6655984) but adds evalf():

def printM(expr, num_digits):
    return expr.xreplace({n.evalf() : round(n, num_digits) for n in expr.atoms(Number)})

With this version, an expression such as sqrt(13) stays sqrt(13) and is not turned into a numeric value. (Otherwise it is converted to a numeric value after the rounding and therefore has more digits than num_digits, at least in SymPy 1.3 on Python 3.7.1.)



Source: https://stackoverflow.com/questions/48491577/printing-the-output-rounded-to-3-decimals-in-sympy
