How can I decorate all inherited methods in a subclass?

Submitted by 别说谁变了你拦得住时间么 on 2021-02-08 06:10:49


class Reader:
    def __init__(self):
        pass

    def fetch_page(self):
        with open('/dev/blockingdevice/mypage.txt') as f:
            return f.read()

    def fetch_another_page(self):
        with open('/dev/blockingdevice/another_mypage.txt') as f:
            return f.read()

class Wrapper(Reader):
    def __init__(self):
        super().__init__()

    def sanity_check(func):
        # Pass the instance (and any arguments) through and return the text
        def wrapper(*args, **kwargs):
            txt = func(*args, **kwargs)
            if 'banned_word' in txt:
                raise Exception('Device has banned word on it!')
            return txt
        return wrapper

    # <how to automatically put this decorator on each function of base class?>

w = Wrapper()

How can I make sure that sanity_check's wrapper runs automatically when fetch_page and fetch_another_page are called on an instance of the Wrapper class?


If you are using Python 3.6 or above, you can accomplish this with __init_subclass__.

A simple implementation (for the real thing you probably want a registry, functools.wraps, etc.):

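class Reader:
    def __init_subclass__(cls):
        # Runs once for every subclass and re-binds these methods on it to
        # versions wrapped by that subclass's sanity_check.
        # (A subclass of Wrapper would wrap them a second time.)
        cls.fetch_page = cls.sanity_check(cls.fetch_page)
        cls.fetch_another_page = cls.sanity_check(cls.fetch_another_page)

    def fetch_page(self):
        return 'banned_word'

    def fetch_another_page(self):
        return 'not a banned word'

class Wrapper(Reader):
    # Plain function in the class body: looked up on the class, it is not
    # bound, so it receives only the function to wrap.
    def sanity_check(func):
        def wrapper(*args, **kw):
            txt = func(*args, **kw)
            if 'banned_word' in txt:
                raise Exception('Device has banned word on it!')
            return txt
        return wrapper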


In [55]: w = Wrapper()

In [56]: w.fetch_another_page()
Out[56]: 'not a banned word'

In [57]: w.fetch_page()
Exception                                 Traceback (most recent call last)
<ipython-input-57-4bb80bcb068e> in <module>()
----> 1 w.fetch_page()

Exception: Device has banned word on it!
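As a possible step toward "the real thing", here is a sketch that wraps every public method defined on Reader instead of listing them by hand, and uses functools.wraps so the wrapped methods keep their names and docstrings. The underscore filter, and the choice to leave methods a subclass overrides itself unwrapped, are assumptions of this sketch rather than part of the answer above. Because it always wraps the original function taken from Reader, deeper subclasses do not double-wrap.

```python
import functools


class Reader:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap every public function defined directly on Reader, using the
        # subclass's sanity_check (a subclass without one raises AttributeError).
        for name, attr in vars(Reader).items():
            if name.startswith('_') or not callable(attr):
                continue
            if name in vars(cls):
                continue  # the subclass overrides it: leave the override alone
            # attr is the raw function from Reader, so nothing is wrapped twice
            setattr(cls, name, cls.sanity_check(attr))

    def fetch_page(self):
        return 'banned_word'

    def fetch_another_page(self):
        return 'not a banned word'


class Wrapper(Reader):
    @staticmethod
    def sanity_check(func):
        @functools.wraps(func)  # keep __name__, __doc__, etc. of the original
        def wrapper(*args, **kwargs):
            txt = func(*args, **kwargs)
            if 'banned_word' in txt:
                raise Exception('Device has banned word on it!')
            return txt
        return wrapper
```

Skipping overridden methods is a judgment call: if you want overrides checked too, wrap getattr(cls, name) instead, at the cost of double-wrapping in deeper subclasses.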

Edit: In case you can't change the base class, you can subclass it with an adapter class:

class Reader:

    def fetch_page(self):
        return 'banned_word'

    def fetch_another_page(self):
        return 'not a banned word'

class ReadAdapter(Reader):
    def __init_subclass__(cls):
        cls.fetch_page = cls.sanity_check(cls.fetch_page)
        cls.fetch_another_page = cls.sanity_check(cls.fetch_another_page)

class Wrapper(ReadAdapter):
    def sanity_check(func):
        def wrapper(*args, **kw):
            txt = func(*args, **kw)
            if 'banned_word' in txt:
                raise Exception('Device has banned word on it!')
            return txt
        return wrapper

This should produce the same result.


There's no easy way to do what you want from within the Wrapper subclass. You need to either name each method of the base class that you want to wrap with a decorator, modify the Wrapper class after you create it (perhaps with a class decorator), or redesign the base class to help you out.
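The class-decorator option can be sketched roughly like this. It needs no changes to the base class; check_inherited is a hypothetical helper name, and this version wraps only public functions defined directly on the given base that the decorated class does not override itself.

```python
import functools


def check_inherited(base, check):
    """Hypothetical helper: a class decorator that wraps each public function
    defined on `base` with `check`, installing the result on the decorated class."""
    def decorate(cls):
        for name, attr in vars(base).items():
            if not name.startswith('_') and callable(attr) and name not in vars(cls):
                setattr(cls, name, check(attr))
        return cls
    return decorate


def sanity_check(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        txt = func(*args, **kwargs)
        if 'banned_word' in txt:
            raise Exception('Device has banned word on it!')
        return txt
    return wrapper


class Reader:  # unchanged base class
    def fetch_page(self):
        return 'banned_word'

    def fetch_another_page(self):
        return 'not a banned word'


@check_inherited(Reader, sanity_check)
class Wrapper(Reader):
    pass
```

The decorator runs once, right after the Wrapper class is created, so Reader itself is untouched.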

One relatively simple redesign would be for the base class methods to be decorated with a decorator that makes them always call a "validator" method. In the base class the validator can be a no-op, but a child class could override it to do whatever you want:

class Base:
    def sanity_check(func):
        def wrapper(self, *args, **kwargs):
            return self.validator(func(self, *args, **kwargs))
        return wrapper

    def validator(self, results):   # this validator accepts everything
        return results

    @sanity_check
    def foo(self):
        return "foo"

    @sanity_check
    def bar(self):
        return "bar"

class Derived(Base):
    def validator(self, results):   # this one doesn't like "bar"
        if results == "bar":
            raise Exception("I don't like bar")
        return results

obj = Derived()
obj.foo()   # works
obj.bar()   # fails to validate


Here is my solution for this:

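import functools

class SubClass(Base):   # Base is the class whose methods should be checked, e.g. Reader
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        for attr_name in Base.__dict__:
            if attr_name.startswith('__'):   # skip __init__, __module__, etc.
                continue
            attr = getattr(self, attr_name)
            if callable(attr):
                # Shadow the bound method with an instance attribute that runs the check
                setattr(self, attr_name, functools.partial(__class__.sanity_check, attr))

    def sanity_check(func):
        txt = func()
        if 'banned_word' in txt:
            raise Exception('Device has banned word on it!')
        return txt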

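This only works if you want every method defined in Base to go through sanity_check. Because sanity_check calls func() with no arguments, it also only works for methods that take no arguments besides self.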
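If the methods take arguments, the same instance-level idea can forward them through a closure instead of functools.partial. A minimal sketch, assuming a Base with a hypothetical fetch_named_page(name) method added only to show argument passing:

```python
import functools


class Base:
    def fetch_page(self):
        return 'banned_word'

    def fetch_named_page(self, name):  # hypothetical method that takes an argument
        return 'page ' + name


class SubClass(Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for attr_name, attr in vars(Base).items():
            # Skip dunders such as __init__ and any non-callable class attributes.
            if attr_name.startswith('__') or not callable(attr):
                continue
            # Shadow the bound method with a checked instance attribute.
            setattr(self, attr_name, self.sanity_check(getattr(self, attr_name)))

    @staticmethod
    def sanity_check(method):
        @functools.wraps(method)
        def wrapper(*args, **kwargs):
            txt = method(*args, **kwargs)  # arguments are forwarded unchanged
            if 'banned_word' in txt:
                raise Exception('Device has banned word on it!')
            return txt
        return wrapper
```

For example, SubClass().fetch_named_page('home') returns 'page home', while SubClass().fetch_page() raises the exception. Plain Base instances are not affected, since only the SubClass instance gets the shadowing attributes.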

