Iterate over 2d array in an expanding circular spiral

自闭症患者 2020-12-23 14:39

Given an n by n matrix M, at row i and column j, I'd like to iterate over all the neighboring values in a c

7 answers
  •  一生所求
    2020-12-23 15:20

    Well, I'm pretty embarrassed that this is the best I've come up with so far, but maybe it will help you. It's not actually a circular iterator (it walks a square spiral), so it can't hand you points nearest-first. That's why I had it take your test function as an argument instead.
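    For comparison, here is a minimal sketch of a genuinely circular iterator. It assumes the array is small enough to sort all of its cells up front. circular_spiral and first_match are hypothetical names, not from the question. Points use the same [row, col] lists and the same test(point, focus, array) interface as the code in this answer. Because points come out nearest-first, the first point that passes test is already the closest one, so no stop radius is needed.

```python
from math import sqrt


def circular_spiral(n_rows, n_cols, focus):
    """Yield every in-bounds [row, col] point, nearest to focus first.

    Hypothetical helper: it sorts all cells by squared Euclidean distance,
    which gives the same order as Euclidean distance without a sqrt.
    """
    fr, fc = focus
    points = [[r, c] for r in range(n_rows) for c in range(n_cols)]
    # Ties are broken by row, then column, so the order is deterministic.
    points.sort(key=lambda p: ((p[0] - fr) ** 2 + (p[1] - fc) ** 2, p[0], p[1]))
    yield from points


def first_match(array, focus, test):
    """Return (point, distance) for the nearest point passing test, else (None, None)."""
    for point in circular_spiral(len(array), len(array[0]), focus):
        if test(point, focus, array):
            distance = sqrt((point[0] - focus[0]) ** 2 + (point[1] - focus[1]) ** 2)
            return point, distance
    return None, None
```

    The trade-off is cost. Sorting takes O(N log N) time for N cells even when the match sits right next to the focus. The lazy square spiral avoids exactly that cost.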

    Problems:

    • Not optimized to skip points outside the array: the spiral still steps through them; it just doesn't yield them.
    • It still uses a square iterator, but it does find the closest point.
    • I haven't used numpy, so it's made for a list of lists. The two places you would adapt are commented.
    • I left the square iterator in long form so it's easier to read. It could be more DRY.
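    On the last point, one way to make the square iterator more DRY is to loop over the four legs as (dimension, direction) pairs. This sketch assumes a rectangular list of lists, passed in as its dimensions, and a focus that lies inside the array. square_spiral is a hypothetical name. It inlines the bounds check and the offset arithmetic rather than calling the answer's helpers.

```python
def square_spiral(n_rows, n_cols, focus):
    """Yield in-bounds [x, y] points in the same order as the answer's
    _square_spiral: legs go left, up, right, down, and each loop around
    extends the legs by one step. x indexes rows, y indexes columns.
    Assumes focus is inside the array.
    """
    # (dimension, direction) for each leg: 0 is x (row), 1 is y (column).
    legs = [(0, -1), (1, 1), (0, 1), (1, -1)]
    size = n_rows * n_cols
    offset = [0, 0]
    yield list(focus)
    count = 1
    r_square = 0
    while count < size:
        r_square += 1
        for dimension, direction in legs:
            # Step along this leg until offset[dimension] == direction * r_square.
            for value in range(offset[dimension] + direction,
                               direction * (1 + r_square), direction):
                offset[dimension] = value
                point = [focus[0] + offset[0], focus[1] + offset[1]]
                # Out-of-array cells are stepped over but not yielded.
                if 0 <= point[0] < n_rows and 0 <= point[1] < n_cols:
                    yield point
                    count += 1
```

    Keeping the legs in a list also makes the traversal order explicit in one place instead of four copied blocks.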

    Here is the code. The key part is the top-level "spiral_search" function. It adds some extra logic on top of the square spiral iterator to make sure that the closest point is found.
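    The extra logic in spiral_search rests on a single bound. The reasoning below assumes something a trace of the square spiral suggests. When the spiral yields a point k cells from the focus along the focus's own axis (Δy = 0, |Δx| = k), it has already visited every cell whose Chebyshev distance from the focus is below k. With r_stop standing for stop_radius:

```latex
\begin{aligned}
d_\infty(p) &= \max(|\Delta x|, |\Delta y|), \qquad d_2(p) = \sqrt{\Delta x^2 + \Delta y^2}, \\
d_\infty(p) &\le d_2(p) \quad \text{for every point } p, \\
p \text{ unvisited} \;\Rightarrow\; d_2(p) &\ge d_\infty(p) \ge k \ge r_{\text{stop}} .
\end{aligned}
```

    So once k >= stop_radius, no unvisited point can be strictly closer than the current best, and breaking out of the loop is safe.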

    from math import sqrt
    
    #constants: index of each coordinate in a point [x, y]; x is the row, y is the column
    X = 0
    Y = 1
    
    def spiral_search(array, focus, test):
        """
        Search for the closest point to focus that satisfies test.
        test interface: test(point, focus, array)
        points structure: [x,y] (list, not tuple)
        returns tuple of best point [x,y] and the euclidean distance from focus
        """
        #stop if focus not in array
        if not _point_is_in_array(focus, array): raise IndexError("Focus must be within the array.")
        #starting closest radius and best point
        stop_radius = None
        best_point = None 
        for point in _square_spiral(array, focus):
            #cheap stop condition: the current point is on the focus's own axis (same Y)
            #and at least stop_radius away, so every point not yet visited is at least as far
            #(only checked on that axis, where it is cheap)
            if stop_radius is not None and point[Y] == focus[Y] and abs(point[X] - focus[X]) >= stop_radius:
                break #current best point is already as good or better, so done
            #otherwise continue testing for closer solutions
            if test(point, focus, array):
                distance = _distance(focus, point)
                if stop_radius is None or distance < stop_radius:
                    stop_radius = distance
                    best_point = point
        return best_point, stop_radius
    
    def _square_spiral(array, focus):
        yield focus
        size = len(array) * len(array[0]) #total cells; for a numpy array, array.size gives the same
        count = 1
        r_square = 0
        offset = [0,0]
        while count < size:
            r_square += 1
            #left
            dimension = X
            direction = -1
            for point in _travel_dimension(array, focus, offset, dimension, direction, r_square):
                yield point
                count += 1
            #up
            dimension = Y
            direction = 1
            for point in _travel_dimension(array, focus, offset, dimension, direction, r_square):
                yield point
                count += 1
            #right
            dimension = X
            direction = 1
            for point in _travel_dimension(array, focus, offset, dimension, direction, r_square):
                yield point
                count += 1
            #down
            dimension = Y
            direction = -1
            for point in _travel_dimension(array, focus, offset, dimension, direction, r_square):
                yield point
                count += 1
    
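    def _travel_dimension(array, focus, offset, dimension, direction, r_square):
        #step offset (in place) along one dimension until it reaches direction * r_square,
        #yielding only the points that fall inside the array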
        for value in range(offset[dimension] + direction, direction*(1+r_square), direction):
            offset[dimension] = value
            point = _offset_to_point(offset, focus)
            if _point_is_in_array(point, array):
                yield point
    
    def _distance(focus, point):
        x2 = (point[X] - focus[X])**2
        y2 = (point[Y] - focus[Y])**2
        return sqrt(x2 + y2)
    
    def _offset_to_point(offset, focus):
        return [offset[X] + focus[X], offset[Y] + focus[Y]]
    
    def _point_is_in_array(point, array):
        #bounds check for a list of lists; for a numpy array, array.shape holds both sizes
        return (0 <= point[X] < len(array)) and (0 <= point[Y] < len(array[0]))
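    A brute-force reference that scans every cell can help check spiral_search against something obviously correct. This is a hypothetical helper, not part of the answer. It uses the same Euclidean formula as _distance, so the distances can be compared exactly.

```python
from math import sqrt


def brute_force_closest(array, focus, test):
    """Scan every cell; return (point, distance) for the closest point
    passing test, or (None, None). Ties keep the first point in row-major order.
    """
    best_point, best_distance = None, None
    for x in range(len(array)):
        for y in range(len(array[0])):
            point = [x, y]
            if test(point, focus, array):
                distance = sqrt((x - focus[0]) ** 2 + (y - focus[1]) ** 2)
                if best_distance is None or distance < best_distance:
                    best_point, best_distance = point, distance
    return best_point, best_distance
```

    On random grids, one could assert that spiral_search(grid, focus, test)[1] equals brute_force_closest(grid, focus, test)[1]. Comparing distances rather than points avoids spurious failures when two matches tie.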
    
