Generate C code with Sympy. Replace Pow(x,2) by x*x

Posted by 孤人 on 2021-01-27 16:00:41

Question


I am generating C code with SymPy, using the Common Subexpression Elimination (CSE) routine and the ccode printer.

However, I would like powers to be printed as explicit products, e.g. (x*x) instead of pow(x,2).

Is there any way to do this?

Example:

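import sympy as sp
a = sp.MatrixSymbol('a', 3, 3)
b = sp.Matrix(a) * sp.Matrix(a)

res = sp.cse(b)

lines = []

# Assignments for the common subexpressions found by cse
for tmp in res[0]:
    lines.append(sp.ccode(tmp[1], tmp[0]))

# Assignments for the reduced result expressions
for i, result in enumerate(res[1]):
    lines.append(sp.ccode(result, "result_%i" % i))

print("\n".join(lines))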

Will output:

x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1, 2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8, 2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11, 2) + x12 + x7;

Best Regards


Answer 1:


SymPy has a function, create_expand_pow_optimization (in sympy.codegen.rewriting), that creates a wrapper to optimise your expressions in this respect. Its argument is the highest power that will be replaced by explicit multiplication.

The wrapper returns the expanded products as UnevaluatedExpr objects, which are protected against the automatic simplification that would otherwise turn x*x back into x**2.

import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

expand_opt = create_expand_pow_optimization(3)

a = sp.Matrix(sp.MatrixSymbol('a',3,3))
res = sp.cse(a@a)

for i,result in enumerate(res[1]):
    print(sp.ccode(expand_opt(result),"result_%i"%i))
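A quick way to see what the limit argument does is to apply the same kind of wrapper to a scalar expression. This is a sketch with a plain Symbol x; the exact ordering of terms in the printed C may differ between SymPy versions, so only the rewritten powers are pointed out in the comments:

```python
import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

x = sp.Symbol('x')

# Rewrite integer powers whose absolute exponent is at most 3 as products
expand_opt = create_expand_pow_optimization(3)

code = sp.ccode(expand_opt(x**5 + x**3 - x**2))
print(code)  # x**5 exceeds the limit and stays pow(x, 5); x**3 and x**2 become products

# Small negative exponents are expanded too, as the reciprocal of a product
inv_code = sp.ccode(expand_opt(1 / x**2))
print(inv_code)
```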

Finally, be aware that at sufficiently high optimisation levels your compiler will do this rewrite itself (and is probably better at it).




Answer 2:


You can subclass the code printer and override only the one method whose output you want to change. You'd need to look at the original SymPy source to find the correct method names and the default implementation, so that you don't introduce errors. With a bit of care, the needed brackets can be generated automatically, exactly when and where they are needed.

Here is a minimal example:

import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence
from sympy.abc import x

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp == 2:
            return '({0} * {0})'.format(self.parenthesize(expr.base, PREC))
        else:
            return super()._print_Pow(expr)

default_printer = C99CodePrinter().doprint
custom_printer = CustomCodePrinter().doprint

expressions = [x, (2 + x) ** 2, x ** 3, x ** 15, sp.sqrt(5), sp.sqrt(x)**4, 1 / x, 1 / (x * x)]
print("Default: {}".format(default_printer(expressions)))
print("Custom: {}".format(custom_printer(expressions)))

Output:

Default: [x, pow(x + 2, 2), pow(x, 3), pow(x, 15), sqrt(5), pow(x, 2), 1.0/x, pow(x, -2)]
Custom: [x, ((x + 2) * (x + 2)), pow(x, 3), pow(x, 15), sqrt(5), (x * x), 1.0/x, pow(x, -2)]
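The same kind of printer can be dropped into the CSE workflow from the question by calling its doprint(expr, assign_to) method where the question calls sp.ccode. A sketch (SquareAsProductPrinter is only an illustrative name; the temporaries that cse produces, and therefore the exact lines printed, differ between SymPy versions):

```python
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence


class SquareAsProductPrinter(C99CodePrinter):
    """Illustrative class name, not part of SymPy: prints b**2 as (b*b)."""

    def _print_Pow(self, expr):
        if expr.exp == 2:
            # parenthesize() brackets the base only when its precedence requires it
            base = self.parenthesize(expr.base, precedence(expr))
            return '({0}*{0})'.format(base)
        # Every other exponent falls back to the stock C99 behaviour
        return super()._print_Pow(expr)


a = sp.Matrix(sp.MatrixSymbol('a', 3, 3))
replacements, reduced = sp.cse(a * a)

printer = SquareAsProductPrinter()
# doprint(expr, assign_to) emits assignment statements, just like ccode() does
lines = [printer.doprint(rhs, lhs) for lhs, rhs in replacements]
lines += [printer.doprint(res, 'result_%i' % i) for i, res in enumerate(reduced)]
code = '\n'.join(lines)
print(code)
```

Only squares are rewritten here; every other power still goes through the default C99 printer.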

PS: To support a wider range of exponents, you could use e.g.

class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp in range(2, 7):
            # Bracket the whole product so it stays intact as a divisor, e.g. y/(x*x*x)
            return '(' + '*'.join([self.parenthesize(expr.base, PREC)] * int(expr.exp)) + ')'
        elif expr.exp in range(-6, 0):
            # Reciprocal of the product, bracketed for the same reason
            return '(1.0/(' + '*'.join([self.parenthesize(expr.base, PREC)] * int(-expr.exp)) + '))'
        else:
            return super()._print_Pow(expr)
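The brackets around the whole product matter because of operator precedence: when a power is the only factor of a denominator, SymPy's C printer emits it right after a single / and adds brackets only if the power's own precedence calls for it, which it never does for a Pow. C then evaluates y/x*x*x as ((y/x)*x)*x. A sketch contrasting a printer that only joins the factors with one that also wraps them (both class names are made up for illustration):

```python
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence

x, y = sp.symbols('x y')


class JoinedPowPrinter(C99CodePrinter):
    """Illustrative printer: writes b**n (2 <= n <= 6) as b*b*...*b."""

    wrap_product = False  # whether to put brackets around the joined product

    def _print_Pow(self, expr):
        if expr.exp in range(2, 7):
            base = self.parenthesize(expr.base, precedence(expr))
            product = '*'.join([base] * int(expr.exp))
            return '(' + product + ')' if self.wrap_product else product
        return super()._print_Pow(expr)


class WrappedJoinedPowPrinter(JoinedPowPrinter):
    """Illustrative printer: same as above, but brackets the whole product."""

    wrap_product = True


naive = JoinedPowPrinter().doprint(y / x**3)
wrapped = WrappedJoinedPowPrinter().doprint(y / x**3)
print(naive)    # y/x*x*x    -> evaluated in C as ((y/x)*x)*x, i.e. y*x
print(wrapped)  # y/(x*x*x)  -> the intended y/x**3
```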



Answer 3:


I think I will go with the user_functions approach.

As suggested in the comment above, I will use the user_functions option of sp.ccode. Assuming we have an expression such as a**3, where a is a Symbol,

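sp.ccode(a**3, user_functions={'Pow': [(lambda x, y: y.is_Integer and y.is_positive, lambda x, y: '(' + '*'.join(['(' + x + ')'] * int(y)) + ')'),
                                       (lambda x, y: True, 'pow')]})

should output: '((a)*(a)*(a))'

Restricting the first rule to positive integer exponents makes a**-2 fall back to pow(a, -2) (multiplying the list by a negative count would otherwise print an empty string), and the outer brackets keep the product intact when it ends up as a divisor, as in b/a**3.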
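The same user_functions mechanism can be packaged as a small function. In the sketch below, pow_as_product_rule is a hypothetical helper name, not part of SymPy. It relies on how the C printer uses the (condition, function) pairs: each condition is called with the SymPy base and exponent, the chosen function receives their already-printed strings, and a plain string such as 'pow' is emitted as an ordinary function call. Only small positive integer exponents are rewritten, and the whole product is bracketed:

```python
import sympy as sp

a, b = sp.symbols('a b')


def pow_as_product_rule(limit):
    """Hypothetical helper: user_functions entry for 'Pow' that writes
    positive integer powers up to `limit` as explicit products."""

    def is_small_positive_int(base, exp):
        # Conditions receive the SymPy objects, not strings
        return exp.is_Integer and 0 < exp <= limit

    def as_product(base, exp):
        # The chosen function receives the printed strings of base and exponent
        return '(' + '*'.join(['(' + base + ')'] * int(exp)) + ')'

    # Anything else is printed as an ordinary pow(base, exp) call
    return [(is_small_positive_int, as_product),
            (lambda base, exp: True, 'pow')]


rules = {'Pow': pow_as_product_rule(4)}
print(sp.ccode(a**3, user_functions=rules))      # ((a)*(a)*(a))
print(sp.ccode(a**-2, user_functions=rules))     # pow(a, -2)
print(sp.ccode(b / a**3, user_functions=rules))  # b/((a)*(a)*(a))
```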

In the future, I will try to improve the function so that it only adds parentheses where they are needed.

Any improvements are welcome!



Source: https://stackoverflow.com/questions/65534432/generate-c-code-with-sympy-replace-powx-2-by-xx
