Check if a field is typing.Optional

試著忘記壹切 submitted on 2020-05-09 05:03:26

Question


What is the best way to check if a field from a class is typing.Optional?

Example code:

from typing import Optional
import re
from dataclasses import dataclass, fields

@dataclass(frozen=True)
class TestClass:
    required_field_1: str
    required_field_2: int
    optional_field: Optional[str]

def get_all_optional_fields(fields) -> list:
    return [field.name for field in fields if __is_optional_field(field)]

def __is_optional_field(field) -> bool:
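    # Raw string so the backslashes are passed to the regex engine unchanged.
    # Note: the str() form of the type differs between Python versions, which makes this fragile.
    regex = r'^typing.Union\[.*, NoneType\]$'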
    return re.match(regex, str(field.type)) is not None

print(get_all_optional_fields(fields(TestClass)))

Here fields is the function from dataclasses, and I want to list all the Optional fields. My current solution uses a regex on the string representation of the field's type, but I don't like this approach. Is there a better way of doing it?


Answer 1:


Optional[X] is equivalent to Union[X, None], so you could do:

from dataclasses import dataclass, fields
from typing import Optional, Union


@dataclass(frozen=True)
class TestClass:
    required_field_1: str
    required_field_2: int
    optional_field: Optional[str]


def get_optional_fields(klass):
    class_fields = fields(klass)
    for field in class_fields:
        if (
            getattr(field.type, "__origin__", None) is Union
            and len(field.type.__args__) == 2
            and field.type.__args__[-1] is type(None)
        ):
            # The type is a Union of exactly two arguments, the last of which is NoneType
            yield field.name


print(list(get_optional_fields(TestClass)))
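The two-argument check above misses optional fields whose union has more than one non-None member, because typing flattens Optional[Union[int, str]] into Union[int, str, None], which has three arguments. A minimal sketch that drops the length test and only asks whether NoneType is among the union's members, assuming CPython 3.7 to 3.13 where Union aliases expose __origin__ and __args__ (the Example class is hypothetical):

```python
from dataclasses import dataclass, fields
from typing import Optional, Union


@dataclass(frozen=True)
class Example:
    # Hypothetical class used only for this illustration
    plain: int
    simple_optional: Optional[str]
    multi_optional: Optional[Union[int, str]]  # flattened to Union[int, str, None]


def get_optional_fields(klass):
    for field in fields(klass):
        # Union aliases carry __origin__ (typing.Union) and __args__ (the member types)
        origin = getattr(field.type, "__origin__", None)
        args = getattr(field.type, "__args__", ())
        if origin is Union and type(None) in args:
            yield field.name


print(list(get_optional_fields(Example)))  # ['simple_optional', 'multi_optional']
```

Checking membership rather than position also makes the order of the union members irrelevant.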



Answer 2:


Note: typing.Optional[x] is an alias for typing.Union[x, None]

To check whether a field annotation is defined as Union[x, None], you can inspect its __module__, __args__ and __origin__ attributes:

from typing import *

def print_meta_info(x):
    print(x.__module__, x.__args__, x.__origin__)

x = Optional[int]
print_meta_info(x) # typing (<class 'int'>, <class 'NoneType'>) typing.Union

x = Union[int, float]
print_meta_info(x) # typing (<class 'int'>, <class 'float'>) typing.Union

x = Iterable[str]
print_meta_info(x) # __args__ is (<class 'str'>,) and __origin__ is collections.abc.Iterable

To define your checker, take these steps:

  1. Make sure that the annotation has the attributes __module__, __args__ and __origin__.
  2. __module__ must be 'typing'. If not, the annotation is not an object defined by the typing module.
  3. __origin__ must be typing.Union.
  4. __args__ must be a tuple of 2 items, the second of which is the class NoneType (type(None)).

If all conditions hold, the annotation is typing.Optional[x].
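The four steps above can be sketched as a single function. This is a minimal sketch assuming CPython 3.7 to 3.13, where Optional[X] objects carry __module__, __args__ and __origin__; following step 4 literally, it only accepts two-member unions with NoneType last, which is how Optional[X] is stored. The name is_optional_annotation is hypothetical.

```python
import typing
from typing import Iterable, Optional, Union


def is_optional_annotation(annotation) -> bool:
    # Step 1: the annotation must expose __module__, __args__ and __origin__
    if not all(hasattr(annotation, attr) for attr in ("__module__", "__args__", "__origin__")):
        return False
    # Step 2: it must be an object from the typing module
    if annotation.__module__ != "typing":
        return False
    # Step 3: its origin must be typing.Union
    if annotation.__origin__ is not typing.Union:
        return False
    # Step 4: exactly two arguments, the second being NoneType
    args = annotation.__args__
    return isinstance(args, tuple) and len(args) == 2 and args[1] is type(None)


print(is_optional_annotation(Optional[int]))      # True
print(is_optional_annotation(Union[int, float]))  # False
print(is_optional_annotation(Iterable[str]))      # False
print(is_optional_annotation(int))                # False
```

Plain classes such as int fail at step 1 because they have no __args__, so no special-casing is needed for them.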

You may also need to know which class the Optional wraps:

x = Optional[int].__args__[0]
print(x) # <class 'int'>
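Indexing __args__[0] works for Optional[X] with a single X. When the union has several non-None members, collecting every member except NoneType may be more useful. A minimal sketch, assuming Python 3.8+ for typing.get_origin and typing.get_args; optional_inner_types is a hypothetical helper name:

```python
import typing
from typing import Optional, Union


def optional_inner_types(annotation) -> tuple:
    """Return the non-None members of a Union, or () if the annotation is not a Union."""
    if typing.get_origin(annotation) is not Union:
        return ()
    return tuple(arg for arg in typing.get_args(annotation) if arg is not type(None))


print(optional_inner_types(Optional[int]))              # (<class 'int'>,)
print(optional_inner_types(Optional[Union[int, str]]))  # (<class 'int'>, <class 'str'>)
print(optional_inner_types(int))                        # ()
```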



Answer 3:


I wrote a library called typedload which can be used to do this.

The main purpose of the library is conversion to/from JSON and namedtuple/dataclass/attrs, but since it needs to do these checks internally, it also exposes the functions that perform them.

Note that different Python versions change how the internal typing API works, so hand-written checks will not work on every Python version.

My library handles this internally, hiding the details from the user.

Using it, the code looks like this:

from typing import *
a = Optional[int]

from typedload import typechecks
typechecks.is_union(a) and type(None) in typechecks.uniontypes(a)

https://github.com/ltworf/typedload

Of course, if you don't need to support multiple Python versions, you might not want to depend on a library just for this, but future releases might still break a hand-written check: the API has changed even between minor releases.




Answer 4:


For reference, Python 3.8 (first released October 2019) added get_origin and get_args functions to the typing module.

Examples from the docs:

from typing import Dict, Union, get_args, get_origin

assert get_origin(Dict[str, int]) is dict
assert get_args(Dict[int, str]) == (int, str)

assert get_origin(Union[int, str]) is Union
assert get_args(Union[int, str]) == (int, str)

This allows:

import typing

def is_optional(field):
    # field is a type annotation such as Optional[int], not a dataclasses.Field
    return typing.get_origin(field) is typing.Union and type(None) in typing.get_args(field)
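The is_optional above takes a type annotation rather than a dataclass Field, so applying it to the question's class means passing field.type. A minimal sketch, assuming Python 3.8+ and reusing the TestClass definition from the question:

```python
import typing
from dataclasses import dataclass, fields
from typing import Optional, Union


@dataclass(frozen=True)
class TestClass:
    required_field_1: str
    required_field_2: int
    optional_field: Optional[str]


def is_optional(field):
    return typing.get_origin(field) is Union and type(None) in typing.get_args(field)


# Pass each field's annotation (f.type), not the Field object itself
optional_names = [f.name for f in fields(TestClass) if is_optional(f.type)]
print(optional_names)  # ['optional_field']
```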

Or, with a fallback for Python versions before 3.8 (this version takes a dataclasses.Field and reads field.type):

import sys
import typing

def is_optional(field):
    if sys.version_info >= (3, 8):
        return typing.get_origin(field.type) is typing.Union and \
            type(None) in typing.get_args(field.type)
    # Before 3.8, read the __origin__ / __args__ attributes directly
    return getattr(field.type, '__origin__', None) is typing.Union and \
        type(None) in getattr(field.type, '__args__', ())


Source: https://stackoverflow.com/questions/56832881/check-if-a-field-is-typing-optional
