Efficient cartesian product excluding items

给你一囗甜甜゛ submitted on 2021-02-10 05:36:06

Question


I am trying to generate every combination of 11 values repeated 80 times, keeping only those whose sum does not exceed 1. The code below does what I want but takes days to run:

import numpy as np
import itertools

unique_values = np.linspace(0.0, 1.0, 11)

lst = []
for p in itertools.product(unique_values, repeat=80):
    if sum(p) <= 1:
        lst.append(p)

This works but takes far too long. I would also have to periodically save 'lst' to disk and free the memory to avoid memory errors. That part is fine, but the code needs days (or maybe weeks) to complete.

Is there any alternative?
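
To put the running time in perspective, here is a rough count, offered as a sketch rather than as part of the original question. It assumes the values are treated as integer tenths (0 to 10, with a budget of 10), and the function name `count_feasible` is made up for illustration. The number of tuples that stay within the budget follows from the standard stars-and-bars argument.

```python
import math

def count_feasible(length, budget):
    """Count tuples of `length` non-negative integers whose sum is <= budget.

    Stars and bars: adding one slack variable turns "sum <= budget" into
    "sum == budget" over length + 1 variables, giving C(length + budget, budget).
    Assumes the largest allowed value is at least `budget`, as it is here.
    """
    return math.comb(length + budget, budget)

total = 11 ** 80                   # every tuple the brute-force loop visits
feasible = count_feasible(80, 10)  # tuples that pass the sum <= 1 filter

print(f"visited by brute force: about 10^{len(str(total)) - 1}")
print(f"passing the filter:     {feasible:,}")
```

Under these assumptions the filter keeps a few trillion tuples out of roughly 10^83 candidates, so enumerating the full product is hopeless, and even the filtered list is too large to hold in memory at once.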


Answer 1:


This is a bit more efficient. It is a generator, so you can take values as you need them:

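import numpy as np

unique_values = np.linspace(0.0, 1.0, 11)  # same values as in the question

def get_solution(uniques, length, constraint):
    # Only values that still fit in the remaining budget are candidates;
    # the 1e-8 tolerance absorbs floating-point rounding error.
    candidates = uniques[uniques <= constraint + 1e-8]
    if length == 1:
        for u in candidates:
            yield u
    else:
        for u in candidates:
            # Recurse with the budget reduced by the value just chosen.
            for s in get_solution(uniques, length - 1, constraint - u):
                yield np.hstack((u, s))

g = get_solution(unique_values, 4, 1)
for _ in range(5):
    print(next(g))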

prints

[0. 0. 0. 0.]
[0.  0.  0.  0.1]
[0.  0.  0.  0.2]
[0.  0.  0.  0.3]
[0.  0.  0.  0.4]

Comparing with your function:

from itertools import product

def get_solution_product(uniques, length, constraint):
    # Brute force: build every tuple, then filter by sum.
    return np.array([p for p in product(uniques, repeat=length) if np.sum(p) <= constraint + 1e-8])

%timeit np.vstack(list(get_solution(unique_values, 5, 1)))
346 ms ± 29.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit get_solution_product(unique_values, 5, 1)
2.94 s ± 256 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
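
The `1e-8` tolerance in the code above works around floating-point rounding, since sums of multiples of 0.1 are not exact in binary floating point. One possible alternative, sketched here under the assumption that the values are always multiples of 0.1, is to run the same pruned recursion on integer tenths and convert to floats only at the end. The name `get_solution_int` is hypothetical and not part of the original answer.

```python
def get_solution_int(max_value, length, budget):
    """Yield tuples of `length` integers in 0..max_value whose sum is <= budget."""
    if length == 0:
        yield ()
        return
    # Values larger than the remaining budget can never fit, so skip them.
    for v in range(min(max_value, budget) + 1):
        for rest in get_solution_int(max_value, length - 1, budget - v):
            yield (v,) + rest

# Integer tenths: 0..10 stand for 0.0..1.0, and a budget of 10 stands for 1.0.
g = get_solution_int(10, 4, 10)
for _ in range(5):
    tenths = next(g)
    print([t / 10 for t in tenths])
```

Because the comparison is done on integers, no tolerance is needed, and the tuples come out in the same order as in the float version.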



Answer 2:


OP simply needs the partitions of 10, but here's some general code I wrote in the meantime.

def find_combinations(values, max_total, repeat):
    # Stop when no slots remain or the budget is used up.
    if not (repeat and max_total > 0):
        yield ()
        return
    for v in values:
        if v <= max_total:
            # Choose v, then fill the remaining slots from the reduced budget.
            for sub_comb in find_combinations(values, max_total - v, repeat - 1):
                yield (v,) + sub_comb


def main():
    all_combinations = find_combinations(range(1, 11), 10, 80)
    # Sort each tuple so that reorderings of the same values collapse to one entry.
    unique_combinations = {
        tuple(sorted(t))
        for t in all_combinations
    }
    for comb in sorted(unique_combinations):
        print(comb)

main()
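
If only the distinct multisets of values matter, as the deduplication with `sorted` above suggests, the tuples can also be generated in non-increasing order directly, so no duplicates are ever produced. The following is a minimal sketch of that idea, not the answerer's code; the name `partitions` and its parameters are made up for illustration.

```python
def partitions(total, max_part=None):
    """Yield each partition of `total` as a non-increasing tuple of positive integers."""
    if max_part is None:
        max_part = total
    if total == 0:
        yield ()
        return
    # The first part is the largest; every later part may not exceed it.
    for first in range(min(total, max_part), 0, -1):
        for rest in partitions(total - first, first):
            yield (first,) + rest

# Partitions of 10, i.e. the ways to reach a sum of exactly 1.0 in steps of 0.1.
for p in partitions(10):
    print(p)
```

Sums below the limit can be covered by calling `partitions(n)` for each `n` from 0 to 10; padding each tuple with zeros up to length 80 then gives the corresponding rows.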


Source: https://stackoverflow.com/questions/59748726/efficient-cartesian-product-excluding-items
