Question
What is the best way to check if a field from a class is typing.Optional?
Example code:
from typing import Optional
import re
from dataclasses import dataclass, fields

@dataclass(frozen=True)
class TestClass:
    required_field_1: str
    required_field_2: int
    optional_field: Optional[str]

def get_all_optional_fields(fields) -> list:
    return [field.name for field in fields if __is_optional_field(field)]

def __is_optional_field(field) -> bool:
    # Raw string so that \[ and \] are regex escapes, not string escapes
    regex = r'^typing.Union\[.*, NoneType\]$'
    return re.match(regex, str(field.type)) is not None

print(get_all_optional_fields(fields(TestClass)))
Here fields is from dataclasses, and I want to list all the Optional fields.
What I'm doing at the moment is matching a regex against the string form of the field type, but I don't like this approach. Is there a better way of doing it?
Answer 1:
Optional[X] is equivalent to Union[X, None]. So you could do:
from typing import Optional
from dataclasses import dataclass, fields

@dataclass(frozen=True)
class TestClass:
    required_field_1: str
    required_field_2: int
    optional_field: Optional[str]

def get_optional_fields(klass):
    class_fields = fields(klass)
    for field in class_fields:
        if (
            hasattr(field.type, "__args__")
            and len(field.type.__args__) == 2
            and field.type.__args__[-1] is type(None)
        ):
            # Exactly two type arguments exist and the last one is NoneType
            yield field.name

print(list(get_optional_fields(TestClass)))
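One caveat with the check above: it only looks at `__args__`, so it also matches other generics that happen to have two type arguments ending in None, such as `Dict[str, None]`. A minimal sketch of a stricter variant that also requires `__origin__` to be `Union` follows. The names `Sample` and `get_optional_fields_strict` are hypothetical and used only for illustration, and the sketch relies on the same internal attributes as the answer, so it is subject to the same version caveats.

```python
from dataclasses import dataclass, fields
from typing import Dict, Optional, Union


@dataclass(frozen=True)
class Sample:  # hypothetical class, for illustration only
    name: str
    nickname: Optional[str]
    flags: Dict[str, None]  # two type arguments, the last one is NoneType


def get_optional_fields_strict(klass):
    """Yield the names of fields annotated as a Union that includes None."""
    for field in fields(klass):
        tp = field.type
        # Require the annotation to be a Union before inspecting its arguments.
        is_union = getattr(tp, "__origin__", None) is Union
        if is_union and type(None) in getattr(tp, "__args__", ()):
            yield field.name


optional_names = list(get_optional_fields_strict(Sample))
print(optional_names)
```

With the `__origin__` test in place, the `flags` field is no longer reported, because its origin is `dict` rather than `Union`.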
Answer 2:
Note: typing.Optional[x] is an alias for typing.Union[x, None].
You can therefore inspect the attributes of a field's annotation to check whether it is defined like Union[x, None].
The relevant attributes are __module__, __args__ and __origin__:
from typing import Iterable, Optional, Union

def print_meta_info(x):
    print(x.__module__, x.__args__, x.__origin__)

# Comments show the output on Python 3.7+; exact reprs vary by version
x = Optional[int]
print_meta_info(x)  # typing (<class 'int'>, <class 'NoneType'>) typing.Union
x = Union[int, float]
print_meta_info(x)  # typing (<class 'int'>, <class 'float'>) typing.Union
x = Iterable[str]
print_meta_info(x)  # typing (<class 'str'>,) <class 'collections.abc.Iterable'>
You need to take these steps to define your checker:
- Make sure that the annotation has the attributes __module__, __args__ and __origin__
- __module__ must be set to 'typing'. If not, the annotation is not an object defined by the typing module
- __origin__ must be equal to typing.Union
- __args__ must be a tuple with 2 items where the second one is the class NoneType (type(None))
If all conditions evaluate to true, you have typing.Optional[x].
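The steps above can be sketched as a single function. This is only an illustration of the listed conditions, not a definitive implementation: the name `is_optional_annotation` is hypothetical, and the attributes it reads are typing internals whose behaviour may differ between Python versions.

```python
from typing import Iterable, Optional, Union


def is_optional_annotation(annotation) -> bool:
    """Apply the four checks listed above, in order (hypothetical helper)."""
    # 1. The annotation must expose all three attributes.
    required = ("__module__", "__args__", "__origin__")
    if not all(hasattr(annotation, name) for name in required):
        return False
    # 2. It must be an object defined by the typing module.
    if annotation.__module__ != "typing":
        return False
    # 3. Its origin must be typing.Union.
    if annotation.__origin__ is not Union:
        return False
    # 4. Exactly two arguments, the second one being NoneType.
    args = annotation.__args__
    return len(args) == 2 and args[1] is type(None)


print(is_optional_annotation(Optional[int]))
print(is_optional_annotation(Union[int, float]))
print(is_optional_annotation(Iterable[str]))
```

Because step 4 demands exactly two arguments, an annotation such as `Optional[Union[int, str]]` (three arguments after flattening) is not recognised by this sketch.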
You may also need to know which class the annotation wraps:
x = Optional[int].__args__[0]
print(x)  # <class 'int'>
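Reading `__args__[0]` directly assumes the annotation really is an Optional. A small sketch that combines the check with the extraction is shown below. The helper name `unwrap_optional` is hypothetical, and it follows the same two-argument convention as the steps in this answer.

```python
from typing import Optional, Union


def unwrap_optional(annotation):
    """Return X for Optional[X]; return any other annotation unchanged."""
    args = getattr(annotation, "__args__", ())
    is_union = getattr(annotation, "__origin__", None) is Union
    if is_union and len(args) == 2 and args[1] is type(None):
        return args[0]
    return annotation


inner = unwrap_optional(Optional[int])
print(inner)
```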
Answer 3:
I wrote a library called typedload which can be used to do this.
The main purpose of the library is conversion to/from json and namedtuple/dataclass/attrs, but since it needed to do those checks, it exposes the functions.
Note that different versions of Python change how the internal typing API works, so hand-written checks will not work on every Python version.
My library handles this internally, hiding the details from the user.
Using it, the code looks like this:
from typing import Optional
from typedload import typechecks

a = Optional[int]
typechecks.is_union(a) and type(None) in typechecks.uniontypes(a)
https://github.com/ltworf/typedload
Of course, if you don't need to support multiple Python versions, you might not want to depend on a library just for this, but future releases might break a hand-written check: the typing API has changed even between minor releases.
Answer 4:
For reference, Python 3.8 (first released October 2019) added the get_origin and get_args functions to the typing module.
Examples from the docs:
from typing import Dict, Union, get_args, get_origin

assert get_origin(Dict[str, int]) is dict
assert get_args(Dict[int, str]) == (int, str)
assert get_origin(Union[int, str]) is Union
assert get_args(Union[int, str]) == (int, str)
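This will allow:
import typing

def is_optional(field):
    # `field` here is the annotation itself; for a dataclass field, pass field.type
    return typing.get_origin(field) is typing.Union and type(None) in typing.get_args(field)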
Or with a fall-back option:
import sys
import typing

def is_optional(field):
    # `field` is a dataclasses.Field; its annotation is field.type
    if sys.version_info >= (3, 8):
        return typing.get_origin(field.type) is typing.Union and \
            type(None) in typing.get_args(field.type)
    # Before 3.8, read the internal attributes directly
    return getattr(field.type, '__origin__', None) is typing.Union and \
        type(None) in getattr(field.type, '__args__', ())
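On Python 3.10 and later, an optional annotation can also be written as `X | None`, and on 3.10 to 3.13 `typing.get_origin` reports `types.UnionType` for it rather than `typing.Union`. The following is a minimal sketch that accepts both spellings, assuming Python 3.8+ for `get_origin` and `get_args`. The names `is_optional_type` and `Example` are hypothetical.

```python
import sys
import types
import typing
from dataclasses import dataclass, fields


def is_optional_type(annotation) -> bool:
    """Return True if the annotation is a union that includes None."""
    union_origins = [typing.Union]
    # types.UnionType (the type of `X | Y`) only exists on Python 3.10+.
    if hasattr(types, "UnionType"):
        union_origins.append(types.UnionType)
    origin = typing.get_origin(annotation)
    is_union = any(origin is candidate for candidate in union_origins)
    return is_union and type(None) in typing.get_args(annotation)


@dataclass(frozen=True)
class Example:  # hypothetical class, for illustration only
    required: int
    legacy: typing.Optional[str]


optional_names = [f.name for f in fields(Example) if is_optional_type(f.type)]
print(optional_names)
```

If the module uses `from __future__ import annotations`, `field.type` is a string rather than a type, so the annotations would first have to be resolved, for example with `typing.get_type_hints`.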
Source: https://stackoverflow.com/questions/56832881/check-if-a-field-is-typing-optional