Filter rows of a numpy array?

别那么骄傲 2020-12-15 03:40

I am looking to apply a function to each row of a numpy array. If this function evaluates to true I will keep the row, otherwise I will discard it. For example, my function

1 answer
  • 2020-12-15 04:06

    Ideally, you would implement a vectorized version of your function and use its result for boolean indexing. For the vast majority of problems this is the right solution. NumPy provides many functions that operate along a chosen axis, in addition to all the basic arithmetic operations and comparisons, so most useful conditions can be vectorized.

    import numpy as np
    
    x = np.random.randn(20, 3)
    x_new = x[np.sum(x, axis=1) > .5]
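
    Other row conditions can usually be vectorized the same way. The following is a minimal sketch, not taken from the question: the two conditions, the thresholds, and the names all_above, bounded and mask are illustrative assumptions. It reduces along axis=1 to get one boolean per row and then combines the results.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only so the sketch is repeatable
x = rng.standard_normal((20, 3))

# Hypothetical condition 1: every entry in the row is greater than -1.
all_above = np.all(x > -1, axis=1)

# Hypothetical condition 2: the largest absolute value in the row is below 2.
bounded = np.abs(x).max(axis=1) < 2

# Row masks combine elementwise with & (and), | (or) and ~ (not).
mask = all_above & bounded
x_new = x[mask]
```

    Each reduction returns a 1-D boolean array with one entry per row, which is the shape boolean indexing needs in order to select whole rows.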
    

    If you are absolutely sure that you can't do the above, I would suggest using a list comprehension (or np.apply_along_axis) to create an array of bools to index with.

    def myfunc(row):
        return sum(row) > .5
    
    bool_arr = np.array([myfunc(row) for row in x])
    x_new = x[bool_arr]
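
    The np.apply_along_axis variant mentioned above could look like the sketch below. It assumes the same myfunc condition, written with row.sum() because each row arrives as a 1-D array; the seeded generator is only there to make the example repeatable.

```python
import numpy as np

def myfunc(row):
    # Same condition as above: keep the row if its entries sum to more than 0.5.
    return row.sum() > .5

rng = np.random.default_rng(0)
x = rng.standard_normal((20, 3))

# axis=1 hands each row to myfunc as a 1-D array; the scalar results are
# collected into a 1-D boolean array with one entry per row.
bool_arr = np.apply_along_axis(myfunc, 1, x)
x_new = x[bool_arr]
```

    np.apply_along_axis still calls myfunc once per row from Python, so it should be treated as a convenience and not as a speedup over the list comprehension.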
    

    This will get the job done in a relatively clean way, but will be significantly slower than a vectorized version (about 38 times slower in the example below):

    x = np.random.randn(5000, 200)
    
    %timeit x[np.sum(x, axis=1) > .5]
    # 100 loops, best of 3: 5.71 ms per loop
    
    %timeit x[np.array([myfunc(row) for row in x])]
    # 1 loops, best of 3: 217 ms per loop
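
    %timeit is only available in IPython. Outside it, a comparable measurement can be sketched with the standard timeit module. The wrapper functions vectorized and looped, the repeat counts, and the seed are assumptions made for this example, and the numbers will vary by machine.

```python
import timeit
import numpy as np

def myfunc(row):
    return sum(row) > .5

rng = np.random.default_rng(0)
x = rng.standard_normal((5000, 200))

def vectorized():
    # Boolean mask computed in one NumPy call.
    return x[np.sum(x, axis=1) > .5]

def looped():
    # Boolean mask built by calling myfunc once per row.
    return x[np.array([myfunc(row) for row in x])]

# Best of 3 repeats, converted to seconds per call.
t_vec = min(timeit.repeat(vectorized, number=10, repeat=3)) / 10
t_loop = min(timeit.repeat(looped, number=1, repeat=3))

print(f"vectorized: {t_vec * 1e3:.2f} ms per call")
print(f"looped:     {t_loop * 1e3:.2f} ms per call")
```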
    