Type hint for NumPy ndarray dtype?

旧时难觅i 2020-12-10 03:44

I would like to give a function a type hint for a NumPy ndarray that also specifies its dtype.

With lists, for example, one could do the following: …
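For comparison, this is roughly what a list type hint that also pins down the element type looks like. It is an illustrative sketch only, since the asker's own example is cut off in the source; `total` is a hypothetical function name.

```python
from typing import List


def total(values: List[int]) -> int:
    # List[int] tells a type checker that every element should be an int
    # (on Python 3.9+ the built-in spelling list[int] works as well).
    return sum(values)


print(total([1, 2, 3]))
```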

4 answers
  • 2020-12-10 03:49

    You could check out nptyping:

    from nptyping import NDArray, Bool
    
    def foo(bar: NDArray[Bool]):
       ...
    

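    Or you could simply use a string as the type hint. A string annotation is not evaluated when the function is defined, so it documents the dtype, but nothing enforces it unless a type checker understands the expression:

    import numpy as np
    
    def foo(bar: 'np.ndarray[np.bool_]'):  # np.bool_ is NumPy's boolean scalar type
        ...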
    
  • 2020-12-10 03:59

    To the best of my knowledge, it is not yet possible to specify the dtype in NumPy array type hints in function signatures. It is planned for some point in the future. See NumPy GitHub issue #7370 and the numpy-stubs repository on GitHub for details on the current development status.
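    For readers on newer versions: NumPy 1.21 and later ship `numpy.typing.NDArray`, a generic alias of `np.ndarray` parameterized by the scalar type, so the dtype can be written into a signature and read by type checkers such as mypy through NumPy's bundled stubs. A minimal sketch, assuming NumPy >= 1.21; `scale` is a hypothetical function name.

```python
import numpy as np
import numpy.typing as npt


def scale(values: npt.NDArray[np.float64], factor: float) -> npt.NDArray[np.float64]:
    # The dtype (float64) is part of the annotation; the shape is left unconstrained.
    return values * factor


print(scale(np.array([1.0, 2.0, 3.0]), 2.0))
```

    The annotation only informs static tools; nothing stops a caller from passing an integer array at runtime.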

  • 2020-12-10 04:00

    One informal way to document array types is the following:

    from typing import TypeVar, Generic
    import numpy as np
    
    Shape = TypeVar("Shape")
    DType = TypeVar("DType")
    
    
    class Array(np.ndarray, Generic[Shape, DType]):
        """
        Use this to type-annotate numpy arrays, e.g.
    
            def transform_image(image: Array['H,W,3', np.uint8], ...):
                ...
    
        """
        pass
    
    
    def func(arr: Array['N,2', int]):
        return arr * 2
    
    
    print(func(arr=np.array([(1, 2), (3, 4)])))
    
    

    We've been using this at my company and made a MyPy checker that actually checks that the shapes work out (which we should release at some point).

    The only catch is that it doesn't make PyCharm happy (i.e. you still get the warning underlines).
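    The mypy checker mentioned above has not been released, so purely as a rough illustration of the same idea, here is a hypothetical runtime decorator (`check_arrays` is a made-up name, not part of any library) that reads `Array['dims', dtype]` annotations and checks each argument when the function is called. Assumptions: named dimensions such as `N` must agree across all arguments of one call, numeric dimensions must match exactly, and the dtype is compared with `np.issubdtype`.

```python
import functools
import inspect
from typing import Generic, TypeVar, get_args, get_origin

import numpy as np

Shape = TypeVar("Shape")
DType = TypeVar("DType")


class Array(np.ndarray, Generic[Shape, DType]):
    """Annotation-only marker class, as in the answer above."""


def _arg_text(arg):
    # Depending on the NumPy/Python version, the shape string comes back
    # either as a plain str or wrapped in typing.ForwardRef; normalise to str.
    return getattr(arg, "__forward_arg__", arg)


def check_arrays(func):
    """Hypothetical decorator: verify Array['dims', dtype] hints at call time."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        sizes = {}  # named dimensions (e.g. N, H, W) seen so far in this call
        for name, value in bound.arguments.items():
            hint = func.__annotations__.get(name)
            if get_origin(hint) is not Array:
                continue  # not an Array[...] annotation
            dims_arg, dtype = get_args(hint)
            dims = [d.strip() for d in _arg_text(dims_arg).split(",")]
            arr = np.asarray(value)
            if arr.ndim != len(dims):
                raise TypeError(f"{name}: expected {len(dims)} dims, got shape {arr.shape}")
            for dim, size in zip(dims, arr.shape):
                # Numeric dims are fixed; named dims bind on first use.
                expected = int(dim) if dim.isdigit() else sizes.setdefault(dim, size)
                if size != expected:
                    raise TypeError(f"{name}: dimension {dim!r} is {size}, expected {expected}")
            if not np.issubdtype(arr.dtype, dtype):
                raise TypeError(f"{name}: dtype {arr.dtype} does not match {dtype}")
        return func(*args, **kwargs)

    return wrapper


@check_arrays
def func(arr: Array['N,2', int]):
    return arr * 2


print(func(np.array([(1, 2), (3, 4)])))
```

    Because the check runs at call time, it adds some per-call overhead and catches nothing before the code runs; it complements a static checker rather than replacing one.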

  • 2020-12-10 04:04

    Check out the data-science-types package.

    pip install data-science-types
    

    It gives mypy access to stubs for NumPy, pandas, and Matplotlib, which allows checks like:

    # program.py
    # Don't evaluate the annotations below at runtime: np.ndarray[dtype] may not be subscriptable there.
    from __future__ import annotations
    
    import numpy as np
    import pandas as pd
    
    arr1: np.ndarray[np.int64] = np.array([3, 7, 39, -3])  # OK
    arr2: np.ndarray[np.int32] = np.array([3, 7, 39, -3])  # Type error
    
    df: pd.DataFrame = pd.DataFrame({'col1': [1, 2, 3], 'col2': [4, 5, 6]})  # OK
    df1: pd.DataFrame = pd.Series([1, 2, 3])  # error: Incompatible types in assignment (expression has type "Series[int]", variable has type "DataFrame")
    

    Then run mypy as usual:

    $ mypy program.py
    