How to get source corresponding to a Python AST node?

一生所求 2020-12-15 22:22

Python AST nodes have lineno and col_offset attributes, which indicate the beginning of the respective code range. Is there an easy way to also get the end of the code range, and thus the source text corresponding to a node?
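For example (a minimal illustration with a made-up snippet), the standard ast module readily exposes only the start position of each node:

    import ast

    source = "x = 1\ny = foo(2)\n"
    tree = ast.parse(source)
    call = tree.body[1].value            # the foo(2) call on the second line
    print(call.lineno, call.col_offset)  # 2 4 -- only the start of the node
    # (Python 3.8+ additionally exposes end_lineno / end_col_offset and
    # ast.get_source_segment(source, node); older versions do not.)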

3 Answers
  • 2020-12-15 23:00

    We had a similar need, and I created the asttokens library for this purpose. It maintains the source in both text and tokenized form, and marks AST nodes with token information, from which text is also readily available.

    It works with Python 2 and 3 (tested with 2.7 and 3.5). For example:

    import ast, asttokens
    st='''
    def greet(a):
      say("hello") if a else say("bye")
    '''
    atok = asttokens.ASTTokens(st, parse=True)
    for node in ast.walk(atok.tree):
      if hasattr(node, 'lineno'):
        print(atok.get_text_range(node), node.__class__.__name__, atok.get_text(node))
    

    Prints

    (1, 50) FunctionDef def greet(a):
      say("hello") if a else say("bye")
    (17, 50) Expr say("hello") if a else say("bye")
    (11, 12) Name a
    (17, 50) IfExp say("hello") if a else say("bye")
    (33, 34) Name a
    (17, 29) Call say("hello")
    (40, 50) Call say("bye")
    (17, 20) Name say
    (21, 28) Str "hello"
    (40, 43) Name say
    (44, 49) Str "bye"
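
    For instance, to pull out just the source text of the call expressions with the same API, something along these lines should work (a small sketch reusing the example source):

    import ast, asttokens

    src = 'say("hello") if a else say("bye")'
    atok = asttokens.ASTTokens(src, parse=True)
    # collect the source text of every Call node
    calls = [atok.get_text(n) for n in ast.walk(atok.tree)
             if isinstance(n, ast.Call)]
    print(calls)   # ['say("hello")', 'say("bye")']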
    
  • 2020-12-15 23:13

    I know it's very late, but I think this is what you are looking for. Here I parse only the function definitions in the module. With this method we can get the first and last line of an AST node, so the source lines of a function definition can be obtained by reading only the lines we need from the source file. This is a very simple example:

    import ast

    st = ('def foo():\n'
          '    print("hello")\n'
          '\n'
          'def bla():\n'
          '    a = 1\n'
          '    b = 2\n'
          '    c = a + b\n'
          '    print(c)')

    tree = ast.parse(st)
    lines = st.split("\n")
    for function in tree.body:
        if isinstance(function, ast.FunctionDef):
            # If the last statement is a compound statement (loop or if),
            # descend into its body to find the real last line
            lastBody = function.body[-1]
            while isinstance(lastBody, (ast.For, ast.While, ast.If)):
                lastBody = lastBody.body[-1]
            lastLine = lastBody.lineno
            print("Name of the function is", function.name)
            print("First line of the function is", function.lineno)
            print("Last line of the function is", lastLine)
            print("The source lines are:")
            for i, line in enumerate(lines, 1):
                if function.lineno <= i <= lastLine:
                    print(line)
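
    With the corrections above, this should print:

    Name of the function is foo
    First line of the function is 1
    Last line of the function is 2
    The source lines are:
    def foo():
        print("hello")
    Name of the function is bla
    First line of the function is 4
    Last line of the function is 8
    The source lines are:
    def bla():
        a = 1
        b = 2
        c = a + b
        print(c)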
    
  • 2020-12-15 23:22

    EDIT: Latest code (tested in Python 3.5-3.7) is here: https://bitbucket.org/plas/thonny/src/master/thonny/ast_utils.py

    As I didn't find an easy way, here's a hard (and probably not optimal) way. It might crash and/or work incorrectly if there are more lineno/col_offset bugs in the Python parser than those mentioned (and worked around) in the code. Tested in Python 3.3:

    import ast
    import _ast
    import io
    import keyword
    import token
    import tokenize

    def mark_code_ranges(node, source):
        """
        Node is an AST, source is corresponding source as string.
        Function adds recursively attributes end_lineno and end_col_offset to each node
        which has attributes lineno and col_offset.
        """
    
        NON_VALUE_KEYWORDS = set(keyword.kwlist) - {'False', 'True', 'None'}
    
    
        def _get_ordered_child_nodes(node):
            if isinstance(node, ast.Dict):
                children = []
                for i in range(len(node.keys)):
                    children.append(node.keys[i])
                    children.append(node.values[i])
                return children
            elif isinstance(node, ast.Call):
                children = [node.func] + node.args
    
                for kw in node.keywords:
                    children.append(kw.value)
    
                if node.starargs != None:
                    children.append(node.starargs)
                if node.kwargs != None:
                    children.append(node.kwargs)
    
                children.sort(key=lambda x: (x.lineno, x.col_offset))
                return children
            else:
                return ast.iter_child_nodes(node)    
    
        def _fix_triple_quote_positions(root, all_tokens):
            """
            http://bugs.python.org/issue18370
            """
            string_tokens = list(filter(lambda tok: tok.type == token.STRING, all_tokens))
    
            def _fix_str_nodes(node):
                if isinstance(node, ast.Str):
                    tok = string_tokens.pop(0)
                    node.lineno, node.col_offset = tok.start
    
                for child in _get_ordered_child_nodes(node):
                    _fix_str_nodes(child)
    
            _fix_str_nodes(root)
    
            # fix their erroneous Expr parents   
            for node in ast.walk(root):
                if ((isinstance(node, ast.Expr) or isinstance(node, ast.Attribute))
                    and isinstance(node.value, ast.Str)):
                    node.lineno, node.col_offset = node.value.lineno, node.value.col_offset
    
        def _fix_binop_positions(node):
            """
            http://bugs.python.org/issue18374
            """
            for child in ast.iter_child_nodes(node):
                _fix_binop_positions(child)
    
            if isinstance(node, ast.BinOp):
                node.lineno = node.left.lineno
                node.col_offset = node.left.col_offset
    
    
        def _extract_tokens(tokens, lineno, col_offset, end_lineno, end_col_offset):
            return list(filter((lambda tok: tok.start[0] >= lineno
                                       and (tok.start[1] >= col_offset or tok.start[0] > lineno)
                                       and tok.end[0] <= end_lineno
                                       and (tok.end[1] <= end_col_offset or tok.end[0] < end_lineno)
                                       and tok.string != ''),
                               tokens))
    
    
    
        def _mark_code_ranges_rec(node, tokens, prelim_end_lineno, prelim_end_col_offset):
            """
            Returns the earliest starting position found in given tree, 
            this is convenient for internal handling of the siblings
            """
    
            # set end markers to this node
            if "lineno" in node._attributes and "col_offset" in node._attributes:
                tokens = _extract_tokens(tokens, node.lineno, node.col_offset, prelim_end_lineno, prelim_end_col_offset)
                #tokens = 
                _set_real_end(node, tokens, prelim_end_lineno, prelim_end_col_offset)
    
            # mark its children, starting from last one
            # NB! need to sort children because eg. in dict literal all keys come first and then all values
            children = list(_get_ordered_child_nodes(node))
            for child in reversed(children):
                (prelim_end_lineno, prelim_end_col_offset) = \
                    _mark_code_ranges_rec(child, tokens, prelim_end_lineno, prelim_end_col_offset)
    
            if "lineno" in node._attributes and "col_offset" in node._attributes:
                # new "front" is beginning of this node
                prelim_end_lineno = node.lineno
                prelim_end_col_offset = node.col_offset
    
            return (prelim_end_lineno, prelim_end_col_offset)
    
        def _strip_trailing_junk_from_expressions(tokens):
            while (tokens[-1].type not in (token.RBRACE, token.RPAR, token.RSQB,
                                          token.NAME, token.NUMBER, token.STRING, 
                                          token.ELLIPSIS)
                        and tokens[-1].string not in ")}]"
                        or tokens[-1].string in NON_VALUE_KEYWORDS):
                del tokens[-1]
    
        def _strip_trailing_extra_closers(tokens, remove_naked_comma):
            level = 0
            for i in range(len(tokens)):
                if tokens[i].string in "({[":
                    level += 1
                elif tokens[i].string in ")}]":
                    level -= 1
    
                if level == 0 and tokens[i].string == "," and remove_naked_comma:
                    tokens[:] = tokens[0:i]
                    return
    
                if level < 0:
                    tokens[:] = tokens[0:i]
                    return   
    
        def _set_real_end(node, tokens, prelim_end_lineno, prelim_end_col_offset):
            # prelim_end_lineno and prelim_end_col_offset are the start of 
            # next positioned node or end of source, ie. the suffix of given
            # range may contain keywords, commas and other stuff not belonging to current node
    
            # Function returns the list of tokens which cover all its children
    
    
            if isinstance(node, _ast.stmt):
                # remove empty trailing lines
                while (tokens[-1].type in (tokenize.NL, tokenize.COMMENT, token.NEWLINE, token.INDENT)
                       or tokens[-1].string in (":", "else", "elif", "finally", "except")):
                    del tokens[-1]
    
            else:
                _strip_trailing_extra_closers(tokens, not isinstance(node, ast.Tuple))
                _strip_trailing_junk_from_expressions(tokens)
    
            # set the end markers of this node
            node.end_lineno = tokens[-1].end[0]
            node.end_col_offset = tokens[-1].end[1]
    
            # Try to peel off more tokens to give better estimate for children
            # Empty parens would confuse the children of no argument Call
            if ((isinstance(node, ast.Call)) 
                and not (node.args or node.keywords or node.starargs or node.kwargs)):
                assert tokens[-1].string == ')'
                del tokens[-1]
                _strip_trailing_junk_from_expressions(tokens)
            # attribute name would confuse the "value" of Attribute
            elif isinstance(node, ast.Attribute):
                if tokens[-1].type == token.NAME:
                    del tokens[-1]
                    _strip_trailing_junk_from_expressions(tokens)
                else:
                    raise AssertionError("Expected token.NAME, got " + str(tokens[-1]))
                    #import sys
                    #print("Expected token.NAME, got " + str(tokens[-1]), file=sys.stderr)
    
            return tokens
    
        all_tokens = list(tokenize.tokenize(io.BytesIO(source.encode('utf-8')).readline))
        _fix_triple_quote_positions(node, all_tokens)
        _fix_binop_positions(node)
        source_lines = source.split("\n") 
        prelim_end_lineno = len(source_lines)
        prelim_end_col_offset = len(source_lines[len(source_lines)-1])
        _mark_code_ranges_rec(node, all_tokens, prelim_end_lineno, prelim_end_col_offset)
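
    A small, hypothetical usage sketch (the source string and loop below are just for illustration; as noted above, this code targets Python 3.3, and it relies on the starargs/kwargs attributes of ast.Call that were removed in Python 3.5):

    import ast

    source = "x = [1,\n     2]\nfoo(x)\n"
    tree = ast.parse(source)
    mark_code_ranges(tree, source)
    # every positioned node should now carry end markers as well
    for n in ast.walk(tree):
        if hasattr(n, "end_lineno"):
            print(n.__class__.__name__,
                  (n.lineno, n.col_offset), "->",
                  (n.end_lineno, n.end_col_offset))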
    