Getting all the nodes from Python AST that correspond to a particular variable with a given name

Submitted by 别来无恙 on 2019-12-06 01:52:54

When walking the AST, track the current scope. Start with a global context; each time you enter a FunctionDef, ClassDef or Lambda node, push a new context onto a stack, and pop it again when you leave that node.

You can then look only at Name nodes that are visited while the global context is on top of the stack. To also catch names that a function declares with a global statement, track those identifiers too (I'd use a set per stack level).
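To see why the context matters, here is a minimal sketch of the naive approach, which scans every node with ast.walk() and ignores scope. The sample source and the helper name naive_name_lines are made up for illustration.

```python
import ast

# Hypothetical sample: one global `x` and one unrelated local `x` inside f().
SOURCE = """\
x = 1

def f():
    x = 2
    return x

print(x)
"""

def naive_name_lines(source, name):
    """Return the line numbers of every Name node with this id, ignoring scope."""
    tree = ast.parse(source)
    return sorted(
        node.lineno
        for node in ast.walk(tree)
        if isinstance(node, ast.Name) and node.id == name
    )

print(naive_name_lines(SOURCE, 'x'))  # [1, 4, 5, 7]
```

Lines 4 and 5 refer to the local x inside f(), not to the global one; those are the false positives that a context stack filters out.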

Using a NodeVisitor subclass:

import ast

class GlobalUseCollector(ast.NodeVisitor):
    def __init__(self, name):
        self.name = name
        # track context name and set of names marked as `global`
        self.context = [('global', ())]

    def visit_FunctionDef(self, node):
        self.context.append(('function', set()))
        self.generic_visit(node)
        self.context.pop()

    # treat coroutines the same way
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node):
        # a class body may contain `global` statements too, so it needs a set
        self.context.append(('class', set()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Lambda(self, node):
        # lambdas are just functions, albeit with no statements, so no assignments
        self.context.append(('function', ()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Global(self, node):
        ctx, g = self.context[-1]
        # at module level every name is already global, so there is nothing to record
        if ctx != 'global':
            g.update(node.names)

    def visit_Name(self, node):
        ctx, g = self.context[-1]
        if node.id == self.name and (ctx == 'global' or node.id in g):
            print('{} used at line {}'.format(node.id, node.lineno))

Demo (given the AST tree for your sample code in t):

>>> GlobalUseCollector('x').visit(t)
x used at line 1
x used at line 10
x used at line 12
x used at line 13

And using global x in a function:

>>> u = ast.parse('''\
... x = 20
...
... def g():
...     global x
...     x = 0
...     for x in range(10):
...         x += 10
...     return x
...
... g()
... for x in range(10):
...     pass
... x += 1
... print(x)
... ''')
>>> GlobalUseCollector('x').visit(u)
x used at line 1
x used at line 5
x used at line 6
x used at line 7
x used at line 8
x used at line 11
x used at line 13
x used at line 14
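
If you want the nodes themselves rather than printed line numbers, one possible variant appends each matching Name node to a list. This is only a sketch: the names GlobalNameNodes, _scoped and global_name_nodes are made up here, and unlike the class above it uses a set at every stack level so that a global statement never needs special handling.

```python
import ast

class GlobalNameNodes(ast.NodeVisitor):
    """Collect Name nodes that refer to the global variable `name`."""

    def __init__(self, name):
        self.name = name
        self.nodes = []
        # one entry per scope: (kind, names declared `global` in that scope)
        self.context = [('global', set())]

    def _scoped(self, kind, node):
        # push a new scope, visit the children, then restore the old scope
        self.context.append((kind, set()))
        self.generic_visit(node)
        self.context.pop()

    def visit_FunctionDef(self, node):
        self._scoped('function', node)

    # treat coroutines the same way
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Lambda(self, node):
        self._scoped('function', node)

    def visit_ClassDef(self, node):
        self._scoped('class', node)

    def visit_Global(self, node):
        self.context[-1][1].update(node.names)

    def visit_Name(self, node):
        kind, declared = self.context[-1]
        if node.id == self.name and (kind == 'global' or node.id in declared):
            self.nodes.append(node)


def global_name_nodes(source, name):
    """Parse `source` and return the Name nodes for the global `name`."""
    collector = GlobalNameNodes(name)
    collector.visit(ast.parse(source))
    return collector.nodes


nodes = global_name_nodes("x = 20\n\ndef g():\n    x = 0\n    return x\n\nprint(x)\n", 'x')
print([(n.lineno, type(n.ctx).__name__) for n in nodes])  # [(1, 'Store'), (7, 'Load')]
```

Each collected node keeps its lineno, col_offset and ctx attributes, so you can tell assignments (Store) from reads (Load) afterwards.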