Dimensionality agnostic (generic) cartesian product

主宰稳场 submitted on 2019-12-01 06:28:23

In plain Python, you can generate the Cartesian product of a collection of iterables using itertools.product.

>>> import itertools
>>> arrays = range(0, 2), range(4, 6), range(8, 10)
>>> list(itertools.product(*arrays))
[(0, 4, 8), (0, 4, 9), (0, 5, 8), (0, 5, 9), (1, 4, 8), (1, 4, 9), (1, 5, 8), (1, 5, 9)]

In NumPy, you can combine numpy.meshgrid (passing sparse=True to avoid expanding the product in memory) with numpy.ndindex:

>>> import numpy as np
>>> arrays = np.arange(0, 2), np.arange(4, 6), np.arange(8, 10)
>>> grid = np.meshgrid(*arrays, sparse=True)
>>> full = np.broadcast_arrays(*grid)  # broadcast views of the sparse grids, no copy
>>> [tuple(g[i].item() for g in full) for i in np.ndindex(full[0].shape)]
[(0, 4, 8), (0, 4, 9), (1, 4, 8), (1, 4, 9), (0, 5, 8), (0, 5, 9), (1, 5, 8), (1, 5, 9)]
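If the full product is needed as a single (rows, dimensions) array, a common dimension-agnostic NumPy idiom (not part of the original post) is to stack dense `indexing="ij"` grids and reshape them. The helper name `cartesian_product` is hypothetical. This allocates the whole product in memory, and `indexing="ij"` gives the same row order as itertools.product rather than the "xy" order shown above.

```python
import numpy as np


def cartesian_product(*arrays):
    """Return every combination of elements of the 1-D ``arrays`` as rows.

    Works for any number of inputs. With indexing="ij", axis k of each grid
    corresponds to arrays[k], so after flattening in C order the last array
    varies fastest, the same order as itertools.product.
    """
    grids = np.meshgrid(*arrays, indexing="ij")
    # Stack the d grids along a new last axis -> shape (n0, n1, ..., d),
    # then flatten everything except that axis -> shape (n0*n1*..., d).
    # np.stack promotes mixed input dtypes to a common dtype.
    return np.stack(grids, axis=-1).reshape(-1, len(arrays))


arrays = np.arange(0, 2), np.arange(4, 6), np.arange(8, 10)
print(cartesian_product(*arrays).tolist())
# [[0, 4, 8], [0, 4, 9], [0, 5, 8], [0, 5, 9], [1, 4, 8], [1, 4, 9], [1, 5, 8], [1, 5, 9]]
```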

I think I found a nice way using a memory-mapped file:

import numpy as np

def carthesian_product_mmap(vectors, filename, mode='w+'):
    '''
    Vectors should be a tuple of 1-D `numpy.ndarray` vectors. You could
    also make it more flexible, and include some error checking.
    '''
    # Make a meshgrid with `copy=False` to create views
    grids = np.meshgrid(*vectors, copy=False, indexing='ij')

    # The shape for concatenating the grids from meshgrid
    shape = grids[0].shape + (len(vectors),)

    # Find the "highest" dtype necessary
    dtype = np.result_type(*vectors)

    # Instantiate the memory-mapped file
    M = np.memmap(filename, dtype, mode, shape=shape)

    # Fill the memmap with the grids, one coordinate column at a time
    for i, grid in enumerate(grids):
        M[..., i] = grid

    # Flush so the data is written to disk now rather than when the
    # memmap is closed (optional, but safer)
    M.flush()

    # Reshape to (number of combinations, number of vectors):
    # one row per element of the Cartesian product
    return M.reshape((-1, len(vectors)))
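A minimal usage sketch, assuming a scratch file in a temporary directory (the file name `product.dat` is arbitrary). The function is repeated with the `grids` fix so the snippet runs on its own. Because `np.memmap` writes a raw buffer with no header, reading the file back requires passing the dtype and shape again.

```python
import os
import tempfile

import numpy as np


def carthesian_product_mmap(vectors, filename, mode="w+"):
    # Same approach as above: one broadcast view per input vector.
    grids = np.meshgrid(*vectors, copy=False, indexing="ij")
    shape = grids[0].shape + (len(vectors),)
    dtype = np.result_type(*vectors)
    M = np.memmap(filename, dtype, mode, shape=shape)
    for i, grid in enumerate(grids):
        M[..., i] = grid
    M.flush()
    return M.reshape((-1, len(vectors)))


vectors = (np.arange(0, 2), np.arange(4, 6), np.arange(8, 10))
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "product.dat")
    rows = carthesian_product_mmap(vectors, path)
    # Re-open the raw file read-only; it has no header, so dtype and shape
    # must be supplied again.
    reopened = np.memmap(path, dtype=rows.dtype, mode="r", shape=rows.shape)
    same = np.array_equal(rows, reopened)
    result = np.array(rows)  # in-memory copy that outlives the temp file
    del rows, reopened  # drop the mappings before the directory is removed

print(result.shape)  # (8, 3)
print(same)          # True
```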

But I wonder whether you really need to store the whole Cartesian product (there's a lot of data duplication). Could you instead generate the rows of the product at the moment they're needed?
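One way to do that, sketched under the assumption that the inputs are 1-D NumPy arrays: number the rows 0 to total - 1 and convert each block of row numbers into per-dimension indices with `np.unravel_index`, so only one block of rows exists at a time. The generator name `iter_product_chunks` and its `chunk_size` parameter are hypothetical. If one row at a time is enough, `itertools.product(*vectors)` is already lazy.

```python
import math

import numpy as np


def iter_product_chunks(vectors, chunk_size=1024):
    """Yield the Cartesian product of 1-D ``vectors`` as 2-D blocks of rows.

    Rows come in itertools.product order (last vector varies fastest), and
    at most ``chunk_size`` rows are materialized at a time.
    """
    shape = tuple(len(v) for v in vectors)
    total = math.prod(shape)
    for start in range(0, total, chunk_size):
        # Row numbers belonging to this block.
        flat = np.arange(start, min(start + chunk_size, total))
        # One index array per dimension (C order: last dimension fastest).
        idx = np.unravel_index(flat, shape)
        # Look up the values and place them side by side as columns.
        yield np.column_stack([v[i] for v, i in zip(vectors, idx)])


vectors = (np.arange(0, 2), np.arange(4, 6), np.arange(8, 10))
for block in iter_product_chunks(vectors, chunk_size=3):
    print(block.tolist())
# [[0, 4, 8], [0, 4, 9], [0, 5, 8]]
# [[0, 5, 9], [1, 4, 8], [1, 4, 9]]
# [[1, 5, 8], [1, 5, 9]]
```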

It seems you just want to loop over an arbitrary number of dimensions. My generic solution for this is to keep an index list, increment the first index, and carry any overflow into the next dimension, like an odometer.

Example:

n = 3 # number of dimensions
N = 1 # highest index value per dimension

idx = [0]*n
while True:
    print(idx)
    # increase first dimension
    idx[0] += 1
    # handle overflows
    for i in range(0, n-1):
        if idx[i] > N:
            # reset this dimension and increase next higher dimension
            idx[i] = 0
            idx[i+1] += 1
    if idx[-1] > N:
        # overflow in the last dimension, we are finished
        break

Gives:

[0, 0, 0]
[1, 0, 0]
[0, 1, 0]
[1, 1, 0]
[0, 0, 1]
[1, 0, 1]
[0, 1, 1]
[1, 1, 1]
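The same idea can be wrapped in a generator that allows a different size per dimension. This is a sketch; `odometer` and `sizes` are hypothetical names, and unlike the loop above it stops carrying as soon as a dimension does not overflow.

```python
def odometer(sizes):
    """Yield every index list of a grid with sizes[k] positions in dimension k.

    Dimension 0 changes fastest, as in the loop above.
    """
    if any(s <= 0 for s in sizes):
        return  # an empty dimension makes the whole product empty
    idx = [0] * len(sizes)
    while True:
        yield list(idx)  # yield a copy so callers can keep it safely
        for k in range(len(sizes)):
            idx[k] += 1
            if idx[k] < sizes[k]:
                break  # no overflow, so no carry is needed
            idx[k] = 0  # overflow: reset this dimension and carry into the next
        else:
            return  # the carry ran past the last dimension: finished


for idx in odometer([3, 2]):
    print(idx)
# prints [0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1], one per line
```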

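NumPy has something similar built in: numpy.ndenumerate, which yields (index, value) pairs for every element of an array; numpy.ndindex yields just the index tuples for a given shape.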
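A short sketch of how these relate to the loop above (the shape `(3, 2)` is arbitrary): `np.ndindex` varies the last axis fastest, so reversing both the shape and each index tuple gives the first-dimension-fastest order of the odometer loop.

```python
import numpy as np

sizes = (3, 2)  # arbitrary example shape

# np.ndindex yields every index tuple of a shape, last axis fastest.
last_fastest = list(np.ndindex(*sizes))

# Reverse the shape and each tuple to get first-dimension-fastest order,
# i.e. the order produced by the odometer loop above.
first_fastest = [idx[::-1] for idx in np.ndindex(*sizes[::-1])]

# np.ndenumerate yields (index, value) pairs for an existing array.
a = np.arange(6).reshape(sizes)
pairs = [(index, int(value)) for index, value in np.ndenumerate(a)]

print(last_fastest)   # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
print(first_fastest)  # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
print(pairs)
```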
