Converting a python numeric expression to LaTeX

前端 未结 5 490
感动是毒
感动是毒 2020-12-08 03:00

I need to convert strings containing valid Python syntax, such as:

'1+2**(x+y)'

and get the equivalent LaTeX:

$1+2^{x+y}$
         


        
5条回答
  •  渐次进展
    2020-12-08 03:36

    Here's a rather long but still incomplete method that doesn't involve sympy in any way. It's enough to cover the example of (-b-sqrt(b**2-4*a*c))/(2*a) which gets translated to \frac{- b - \sqrt{b^{2} - 4 \; a \; c}}{2 \; a} and renders as

    [rendered image not preserved; it showed $\frac{- b - \sqrt{b^{2} - 4 \; a \; c}}{2 \; a}$]

    It basically creates the AST and walks it, producing the LaTeX math that corresponds to the AST nodes. What's there should give enough of an idea of how to extend it in the places where it's lacking.

    
    import ast
    
    class LatexVisitor(ast.NodeVisitor):
        """Render a Python expression AST as a LaTeX math string.

        Every ``visit_<Node>`` method returns the LaTeX for one node type, and
        every ``prec_<Node>`` method returns that node's binding strength
        (higher binds tighter).  Operands are wrapped in parentheses only when
        these strengths say it is necessary, so the rendered grouping matches
        the way Python groups the expression.  Node types without a
        ``visit_`` method are handled by ``generic_visit``.
        """

        # Operators whose right operand may sit at the same precedence level
        # without parentheses, because regrouping keeps the mathematical
        # meaning: a + (b - c) == a + b - c, whereas a - (b - c) != a - b - c.
        _ASSOCIATIVE_OPS = (ast.Add, ast.Mult, ast.BitOr, ast.BitXor, ast.BitAnd)

        def prec(self, n):
            """Return the precedence of node *n*, falling back to ``generic_prec``."""
            return getattr(self, 'prec_' + n.__class__.__name__, self.generic_prec)(n)

        def _group(self, child, op_prec, inclusive=False):
            """Render *child*, parenthesized if it binds looser than its operator.

            *op_prec* is the precedence of the operator that owns *child*.  When
            *inclusive* is true, a child at exactly that precedence is wrapped
            too (needed for non-associative operand positions).
            """
            text = self.visit(child)
            child_prec = self.prec(child)
            if child_prec < op_prec or (inclusive and child_prec == op_prec):
                return r'\left(%s\right)' % text
            return text

        def visit_Call(self, n):
            """Render ``sqrt(x)`` as a radical and any other call as an operator name."""
            func = self.visit(n.func)
            args = ', '.join(map(self.visit, n.args))
            if func == 'sqrt':
                # Raw string: '\s' is an invalid escape sequence in a plain literal.
                return r'\sqrt{%s}' % args
            return r'\operatorname{%s}\left(%s\right)' % (func, args)

        def prec_Call(self, n):
            return 1000

        def visit_Name(self, n):
            return n.id

        def prec_Name(self, n):
            return 1000

        def visit_Constant(self, n):
            # Python 3.8+ parses every literal as ast.Constant.  Handling it here
            # gives numbers atom precedence (otherwise they fell through to
            # generic_prec == 0 and were needlessly parenthesized); visit_Num
            # below is only reached on interpreters older than 3.8.
            return str(n.value)

        def prec_Constant(self, n):
            return 1000

        def visit_Num(self, n):
            return str(n.n)

        def prec_Num(self, n):
            return 1000

        def visit_UnaryOp(self, n):
            """Render a prefix operator, parenthesizing a looser-binding operand."""
            op = self.visit(n.op)
            operand = self.visit(n.operand)
            if self.prec(n.op) > self.prec(n.operand):
                return r'%s \left(%s\right)' % (op, operand)
            return r'%s %s' % (op, operand)

        def prec_UnaryOp(self, n):
            return self.prec(n.op)

        def visit_BinOp(self, n):
            """Render a binary operation, adding only the parentheses it needs."""
            if isinstance(n.op, ast.Div):
                # \frac delimits both operands itself, so neither needs parentheses.
                return r'\frac{%s}{%s}' % (self.visit(n.left), self.visit(n.right))
            if isinstance(n.op, ast.FloorDiv):
                return r'\left\lfloor\frac{%s}{%s}\right\rfloor' % (self.visit(n.left), self.visit(n.right))
            op_prec = self.prec(n.op)
            if isinstance(n.op, ast.Pow):
                # ** is right-associative and binds tighter than unary minus, so a
                # base at the same or a lower level -- (a**b)**c, (-2)**2 -- must
                # be parenthesized.  The exponent is already grouped by braces.
                base = self._group(n.left, op_prec, inclusive=True)
                return r'%s^{%s}' % (base, self.visit(n.right))
            left = self._group(n.left, op_prec)
            # The remaining operators are left-associative: a right operand at
            # the same level keeps its parentheses unless the operator is
            # associative, so a - (b - c) does not collapse to a - b - c.
            right = self._group(n.right, op_prec,
                                inclusive=not isinstance(n.op, self._ASSOCIATIVE_OPS))
            return r'%s %s %s' % (left, self.visit(n.op), right)

        def prec_BinOp(self, n):
            return self.prec(n.op)

        # Binary operator precedence levels used below (loosest to tightest):
        # | 100, ^ 150, & 200, << >> 250, + - 300, * / // 400, % 500,
        # unary + - ~ 600, ** 700.

        def visit_Sub(self, n):
            return '-'

        def prec_Sub(self, n):
            return 300

        def visit_Add(self, n):
            return '+'

        def prec_Add(self, n):
            return 300

        def visit_Mult(self, n):
            return '\\;'

        def prec_Mult(self, n):
            return 400

        def visit_Mod(self, n):
            return '\\bmod'

        def prec_Mod(self, n):
            return 500

        def prec_Pow(self, n):
            return 700

        def prec_Div(self, n):
            return 400

        def prec_FloorDiv(self, n):
            return 400

        def visit_LShift(self, n):
            return '\\operatorname{shiftLeft}'

        def prec_LShift(self, n):
            return 250

        def visit_RShift(self, n):
            return '\\operatorname{shiftRight}'

        def prec_RShift(self, n):
            return 250

        def visit_BitOr(self, n):
            return '\\operatorname{or}'

        def prec_BitOr(self, n):
            return 100

        def visit_BitXor(self, n):
            return '\\operatorname{xor}'

        def prec_BitXor(self, n):
            return 150

        def visit_BitAnd(self, n):
            return '\\operatorname{and}'

        def prec_BitAnd(self, n):
            return 200

        # Unary +, - and ~ bind looser than ** (Python reads -2**2 as -(2**2)),
        # so they sit below Pow; ranking them above it rendered (-2)**2 as -2^{2}.

        def visit_Invert(self, n):
            return '\\operatorname{invert}'

        def prec_Invert(self, n):
            return 600

        def visit_Not(self, n):
            return '\\neg'

        def prec_Not(self, n):
            return 800

        def visit_UAdd(self, n):
            return '+'

        def prec_UAdd(self, n):
            return 600

        def visit_USub(self, n):
            return '-'

        def prec_USub(self, n):
            return 600

        def generic_visit(self, n):
            """Fallback for values that have no dedicated ``visit_`` method.

            An AST node is rendered as its class name applied to its visited
            fields, a list (e.g. ``Compare.ops``) as its comma-joined visited
            items, and anything else via ``str``.
            """
            if isinstance(n, ast.AST):
                # getattr default: optional fields may be absent on some versions.
                fields = ', '.join(self.visit(getattr(n, f, None)) for f in n._fields)
                return r'\operatorname{%s}\left(%s\right)' % (n.__class__.__name__, fields)
            if isinstance(n, list):
                return ', '.join(map(self.visit, n))
            return str(n)

        def generic_prec(self, n):
            return 0
    
    def py2tex(expr):
        """Convert a single Python expression string to a LaTeX math string.

        Surrounding whitespace is ignored.  Input that is not exactly one
        expression (empty text, or a statement such as ``x = 1``) raises
        ``SyntaxError``.  It is no longer silently truncated to an assigned
        value, and it no longer fails with ``IndexError``/``AttributeError``.
        """
        # mode='eval' makes the parser accept only a lone expression and hands
        # back that expression node directly as ``tree.body``.
        tree = ast.parse(expr.strip(), mode='eval')
        return LatexVisitor().visit(tree.body)
    
    

提交回复
热议问题