I need a working approach for getting all classes that inherit from a base class in Python.
Note: I see that someone (not @unutbu) changed the referenced answer so that it no longer uses vars()['Foo'] — so the primary point of my post no longer applies.
FWIW, here's what I meant about @unutbu's answer only working with locally defined classes — and that using eval() instead of vars() would make it work with any accessible class, not only those defined in the current scope.
For those who dislike using eval(), a way to avoid it is also shown.
First, here's a concrete example demonstrating the potential problem with using vars():
# Small demonstration hierarchy used by the examples that follow:
#
#     Foo
#     |-- Bar
#     |   `-- Bing
#     `-- Baz
class Foo(object):
    pass


class Bar(Foo):
    pass


class Baz(Foo):
    pass


class Bing(Bar):
    pass
# unutbu's approach
def all_subclasses(cls):
    """Return every direct and indirect subclass of ``cls`` as a list.

    The direct subclasses come first, followed by the descendants of each
    direct subclass in turn.  A class that is reachable through more than
    one inheritance path (multiple inheritance) is listed only once, at the
    position of its first occurrence.

    Args:
        cls: The class whose subclass hierarchy is collected.

    Returns:
        A list of class objects; empty when ``cls`` has no subclasses.
    """
    # Query the direct subclasses once and reuse the result below.
    direct = cls.__subclasses__()
    collected = direct + [g for s in direct for g in all_subclasses(s)]

    # With diamond-shaped hierarchies the same class shows up once per
    # inheritance path; keep only the first occurrence of each.
    seen = set()
    unique = []
    for klass in collected:
        if klass not in seen:
            seen.add(klass)
            unique.append(klass)
    return unique
# At module level vars() reflects the module namespace, so 'Foo' is found.
print(all_subclasses(vars()['Foo'])) # Fine because Foo is in scope
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
def func():  # won't work because Foo class is not locally defined
    """Show that ``vars()['Foo']`` fails inside a function body.

    Called without arguments inside a function, ``vars()`` only sees the
    function's local names; ``Foo`` is not one of them, so the lookup
    raises ``KeyError``.
    """
    local_names = vars()
    print(all_subclasses(local_names['Foo']))
try:
    func()  # not OK because Foo is not local to func()
except Exception as exc:
    # Report the failure rather than aborting the rest of the demonstration.
    print('calling func() raised exception: {!r}'.format(exc))
# -> calling func() raised exception: KeyError('Foo',)
#    (Python 3.7+ shows the repr without the trailing comma: KeyError('Foo'))
# eval('Foo') resolves the name through the normal name lookup rules.
print(all_subclasses(eval('Foo'))) # OK
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
# using eval('xxx') instead of vars()['xxx']
def func2():
    """Print the subclasses of ``Foo``, resolving the name with ``eval()``."""
    foo_cls = eval('Foo')
    print(all_subclasses(foo_cls))
func2() # Works
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
This could be improved by moving the eval('ClassName') call down into the function itself, which makes it easier to use without losing the additional generality gained by using eval(), which, unlike vars(), is not context-sensitive:
# easier to use version
def all_subclasses2(classname):
    """Return every direct and indirect subclass of the class named ``classname``.

    The name is resolved exactly once with ``eval()``; the hierarchy below it
    is then walked using the class objects themselves.  Subclasses therefore
    do not need to be reachable by their bare ``__name__`` from this module
    (they may live in other modules or inside functions), and two classes
    that happen to share a name cannot be confused with each other.

    Args:
        classname: A string that evaluates to a class, e.g. ``'Foo'``.

    Returns:
        A list of class objects: the direct subclasses first, followed by the
        descendants of each direct subclass in turn.  A class reachable
        through several inheritance paths is listed once.

    Raises:
        Whatever ``eval()`` raises for the given string, e.g. ``NameError``
        when the name is unknown.
    """
    # NOTE(security): eval() executes arbitrary expressions -- only ever pass
    # trusted strings here.  It is evaluated first, before any other local
    # name exists, so the expression sees the same locals as before.
    root = eval(classname)

    def _descendants(klass):
        # Direct subclasses first, then each one's own descendants.
        direct = klass.__subclasses__()
        return direct + [g for s in direct for g in _descendants(s)]

    # Diamond-shaped hierarchies reach the same class more than once; keep
    # only the first occurrence of each.
    seen = set()
    unique = []
    for klass in _descendants(root):
        if klass not in seen:
            seen.add(klass)
            unique.append(klass)
    return unique
# pass 'xxx' instead of eval('xxx')
def func_ez():
    """Print the subclasses of ``Foo``, passing only the class name."""
    found = all_subclasses2('Foo')  # simpler
    print(found)
func_ez()
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]
Lastly, it's possible, and perhaps even important in some cases, to avoid using eval() for security reasons, so here's a version without it:
def get_all_subclasses(cls):
    """Generate every direct and indirect subclass of ``cls``.

    Classes are produced depth-first: each subclass is yielded, immediately
    followed by its own descendants.  A class that is reachable through more
    than one inheritance path is yielded only once, on its first visit.

    Args:
        cls: The class whose subclass hierarchy is traversed.

    Yields:
        Class objects below ``cls`` in the hierarchy (``cls`` itself is not
        included).
    """
    seen = set()

    def _walk(klass):
        # Only the __subclasses__() call itself is guarded.  It raises
        # TypeError for metaclasses such as ``type``, where the attribute
        # resolves to an unbound method that needs an explicit argument;
        # such classes are treated as having no subclasses.
        try:
            children = klass.__subclasses__()
        except TypeError:
            return
        for child in children:
            if child in seen:
                # Already produced, together with its whole subtree.
                continue
            seen.add(child)
            yield child
            for descendant in _walk(child):
                yield descendant

    for subclass in _walk(cls):
        yield subclass
def all_subclasses3(classname):
    """Return every subclass of the class named ``classname`` without ``eval()``.

    The name is looked up once by scanning the subclasses of ``object``; the
    first class whose ``__name__`` matches is used.  The hierarchy below that
    class is then walked using the class objects themselves, so the scan is
    not repeated for every subclass and a same-named but unrelated class
    cannot be picked up part-way through the walk.

    Args:
        classname: The bare name of the class to look up, e.g. ``'Foo'``.

    Returns:
        A list of class objects: the direct subclasses first, followed by the
        descendants of each direct subclass in turn.  A class reachable
        through several inheritance paths is listed once.

    Raises:
        ValueError: If no subclass of ``object`` has the requested name.
    """
    # object is the base of all new-style classes, so scanning below it finds
    # any such class by name.  Only the last dotted component is compared.
    root = next(
        (candidate for candidate in get_all_subclasses(object)
         if candidate.__name__.split('.')[-1] == classname),
        None,
    )
    if root is None:
        raise ValueError('class %s not found' % classname)

    def _descendants(klass):
        # Direct subclasses first, then each one's own descendants.
        direct = klass.__subclasses__()
        return direct + [g for s in direct for g in _descendants(s)]

    # Diamond-shaped hierarchies reach the same class more than once; keep
    # only the first occurrence of each.
    seen = set()
    unique = []
    for klass in _descendants(root):
        if klass not in seen:
            seen.add(klass)
            unique.append(klass)
    return unique
# no eval('xxx')
def func3():
    """Print the subclasses of ``Foo`` using the eval-free name lookup."""
    found = all_subclasses3('Foo')
    print(found)
func3() # Also works
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]