Making Sieve of Eratosthenes more memory efficient in python?

╄→гoц情女王★ submitted on 2021-02-05 05:32:28

Question


Sieve of Eratosthenes memory constraint issue

I'm currently trying to implement a version of the Sieve of Eratosthenes for a Kattis problem; however, I am running into memory constraints that my implementation won't pass.

Here is a link to the problem statement. In short, the problem wants me to first return the number of primes less than or equal to n and then answer, for a certain number of queries, whether a number i is prime. There is a constraint of 50 MB memory usage, and only the standard libraries of Python may be used (no numpy etc). The memory constraint is where I am stuck.

Here is my code so far:

import sys

def sieve_of_eratosthenes(xs, n):
    count = len(xs) + 1
    p = 3 # start at three
    index = 0
    while p*p <= n:  # <= so that p*p itself is crossed out when n == p*p (e.g. n = 9 or 25)
        for i in range(index + p, len(xs), p):
            if xs[i]:
                xs[i] = 0
                count -= 1

        temp_index = index
        for i in range(index + 1, len(xs)):
            if xs[i]:
                p = xs[i]
                temp_index += 1
                break
            temp_index += 1
        index = temp_index

    return count


def isPrime(xs, a):
    if a == 1:
        return False
    if a == 2:
        return True
    if not (a & 1):
        return False
    return bool(xs[(a >> 1) - 1])

def main():
    n, q = map(int, sys.stdin.readline().split(' '))
    odds = [num for num in range(2, n+1) if (num & 1)]
    print(sieve_of_eratosthenes(odds, n))

    for _ in range(q):
        query = int(input())
        if isPrime(odds, query):
            print('1')
        else:
            print('0')


if __name__ == "__main__":
    main()

I've made some improvements so far, such as keeping only a list of the odd numbers, which halves the memory usage. I am also certain that the code works as intended when calculating the primes (I am not getting a wrong answer). My question now is: how can I make my code even more memory efficient? Should I use some other data structure? Replace my list of integers with booleans? Use a bitarray?

Any advice is much appreciated!

EDIT

After some tweaking to the code in python I hit a wall where my implementation of a segmented sieve would not pass the memory requirements.

Instead, I chose to implement the solution in Java, which took very little effort. Here is the code:

  public int sieveOfEratosthenes(int n){
    sieve = new BitSet((n+1) / 2);
    int count = (n + 1) / 2;

    for (int i=3; i*i <= n; i += 2){
      if (isComposite(i)) {
        continue;
      }

      // Increment by two, skipping all even numbers
      for (int c = i * i; c <= n; c += 2 * i){
        if(!isComposite(c)){
          setComposite(c);
          count--;
        }
      }
    }

    return count;

  }

  public boolean isComposite(int k) {
    return sieve.get((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public void setComposite(int k) {
    sieve.set((k - 3) / 2); // Since we don't keep track of even numbers
  }

  public boolean isPrime(int a) {
    if (a < 3)
      return a > 1;

    if (a == 2)
      return true;

    if ((a & 1) == 1)
      return !isComposite(a);
    else
      return false;

  }

  public void run() throws Exception{
    BufferedReader scan = new BufferedReader(new InputStreamReader(System.in));
    String[] line = scan.readLine().split(" ");

    int n = Integer.parseInt(line[0]); int q = Integer.parseInt(line[1]);
    System.out.println(sieveOfEratosthenes(n));

    for (int i=0; i < q; i++){
      line = scan.readLine().split(" ");
      System.out.println( isPrime(Integer.parseInt(line[0])) ? '1' : '0');
    }
  }

I have personally not found a way to implement this BitSet solution in Python (using only the standard library).

If anyone stumbles across a neat implementation to the problem in python, using a segmented sieve, bitarray or something else, I would be interested to see the solution.


Answer 1:


This is a very challenging problem indeed. With a maximum possible N of 10^8, using one byte per value results in almost 100 MB of data assuming no overhead whatsoever. Even halving the data by only storing odd numbers will put you very close to 50 MB after overhead is considered.

This means the solution will have to make use of one or more of a few strategies:

  1. Using a more efficient data type for our array of primality flags. Python lists maintain an array of pointers to each list item (8 bytes each on a 64 bit python). We effectively need raw binary storage, which pretty much only leaves bytearray in standard python.
  2. Using only one bit per value in the sieve instead of an entire byte (Bool technically only needs one bit, but typically uses a full byte).
  3. Sub-dividing to remove even numbers, and possibly also multiples of 3, 5, 7 etc.
  4. Using a segmented sieve
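
To make the first two strategies concrete, here is a rough sketch that compares the container sizes reported by sys.getsizeof for one million flags. The variable names are mine, and the exact numbers depend on the Python build, so treat it as an illustration rather than a measurement of the Kattis environment.

```python
import sys

n = 1_000_000  # number of primality flags to store (illustrative size)

# A list stores one pointer per element (the small int objects are shared).
list_bytes = sys.getsizeof([0] * n)

# A bytearray stores one raw byte per element.
array_bytes = sys.getsizeof(bytearray(n))

# Packing eight flags into each byte needs roughly n / 8 bytes.
bits_bytes = sys.getsizeof(bytearray((n + 7) // 8))

print("list:            ", list_bytes, "bytes")
print("bytearray:       ", array_bytes, "bytes")
print("bit-packed array:", bits_bytes, "bytes")
```

Note that getsizeof only reports the container itself, which is the part that matters here because the flag values are shared small integers.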

I initially tried to solve the problem by storing only 1 bit per value in the sieve, and while the memory usage was indeed within the requirements, Python's slow bit manipulation pushed the execution time far too long. It also was rather difficult to figure out the complex indexing to make sure the correct bits were being counted reliably.
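
For reference, the indexing for a one-bit-per-odd-number sieve can be sketched roughly as below. This is not the code I originally wrote; the function names are made up for illustration, math.isqrt needs Python 3.8 or later, and, as noted above, the per-bit loop is likely too slow in pure Python for n = 10^8.

```python
from math import isqrt

def make_bit_sieve(n):
    """Bit i of the returned bytearray is set when the odd number 2*i + 1 is NOT prime."""
    size = (n + 1) // 2                  # how many odd numbers lie in 1..n
    bits = bytearray((size + 7) // 8)    # eight flags per byte
    if size:
        bits[0] |= 1                     # the number 1 is not prime
    for p in range(3, isqrt(n) + 1, 2):
        i = p >> 1                       # index of the odd number p
        if bits[i >> 3] & (1 << (i & 7)):
            continue                     # p is composite, skip it
        # Odd multiples of p are p apart in index space, starting at p*p.
        for j in range((p * p) >> 1, size, p):
            bits[j >> 3] |= 1 << (j & 7)
    return bits

def is_prime(bits, a):
    """Primality of a, assuming a <= n for the n the sieve was built with."""
    if a < 2:
        return False
    if a % 2 == 0:
        return a == 2
    i = a >> 1
    return not (bits[i >> 3] >> (i & 7)) & 1

def count_primes(bits, n):
    size = (n + 1) // 2
    crossed_out = sum(bin(b).count("1") for b in bits)
    return size - crossed_out + (1 if n >= 2 else 0)   # the extra 1 accounts for 2
```

With this layout the flags for n = 10^8 occupy about 6.25 MB, which is where the memory saving comes from.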

I then implemented the odd numbers only solution using a bytearray and while it was quite a bit faster, the memory was still an issue.

Bytearray implementation (as posted, it keeps one byte for every number from 0 to n, not only for the odd numbers):

class Sieve:
    def __init__(self, n):
        self.not_prime = bytearray(n+1)
        self.not_prime[0] = self.not_prime[1] = 1
        for i in range(2, int(n**.5)+1):
            if self.not_prime[i] == 0:
                self.not_prime[i*i::i] = [1]*len(self.not_prime[i*i::i])
        self.n_prime = n + 1 - sum(self.not_prime)
        
    def is_prime(self, n):
        return int(not self.not_prime[n])
        


def main():
    n, q = map(int, input().split())
    s = Sieve(n)
    print(s.n_prime)
    for _ in range(q):
        i = int(input())
        print(s.is_prime(i))

if __name__ == "__main__":
    main()

Further reduction in memory from this should* make it work.

EDIT: also removing multiples of 2 and 3 did not seem to be enough memory reduction even though guppy.hpy().heap() seemed to suggest my usage was in fact a bit under 50MB. 🤷‍♂️
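
For completeness, here is a minimal sketch of what an odd-numbers-only bytearray sieve could look like, using slice assignment so the crossing-out runs in C. The function names are hypothetical, math.isqrt needs Python 3.8 or later, and the array still takes about n / 2 bytes, so on its own it probably does not fit under 50 MB for n = 10^8.

```python
from math import isqrt

def odd_sieve(n):
    """flags[i] == 1 means the odd number 2*i + 1 is prime."""
    size = (n + 1) // 2                    # odd numbers in 1..n
    flags = bytearray(b"\x01") * size
    if size:
        flags[0] = 0                       # 1 is not prime
    for i in range(1, (isqrt(n) - 1) // 2 + 1):
        if flags[i]:
            p = 2 * i + 1
            start = (p * p) // 2           # index of p*p
            # Zero every p-th flag from p*p onwards without a Python-level loop.
            flags[start::p] = bytes(len(range(start, size, p)))
    return flags

def count_primes(flags, n):
    return flags.count(1) + (1 if n >= 2 else 0)   # add the prime 2

def is_prime(flags, a):
    """Primality of a, assuming 0 <= a <= n."""
    if a % 2 == 0:
        return a == 2
    return bool(flags[a // 2])
```

The right-hand side is built with bytes(len(range(...))) so that no temporary list or slice copy of the sieve is created.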




Answer 2:


There's a trick I learned just yesterday - if you divide the numbers into groups of 6, only 2 of the 6 may be prime. The others can be evenly divided by either 2 or 3. That means it only takes 2 bits to track the primality of 6 numbers; a byte containing 8 bits can track primality for 24 numbers! This greatly diminishes the memory requirements of your sieve.

In Python 3.7.5 64 bit on Windows 10, the following code didn't go over 36.4 MB.

remainder_bit = [0, 0x01, 0, 0, 0, 0x02,
                 0, 0x04, 0, 0, 0, 0x08,
                 0, 0x10, 0, 0, 0, 0x20,
                 0, 0x40, 0, 0, 0, 0x80]

def is_prime(xs, a):
    if a <= 3:
        return a > 1
    index, rem = divmod(a, 24)
    bit = remainder_bit[rem]
    if not bit:
        return False
    return not (xs[index] & bit)

def sieve_of_eratosthenes(xs, n):
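    # Candidates are the numbers 6k+1 and 6k+5 up to n: drop 1 (not prime),
    # then add the primes 2 and 3. Assumes n >= 3.
    count = (n + 5) // 6 + (n + 1) // 6 + 1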
    p = 5
    while p*p <= n:
        if is_prime(xs, p):
            for i in range(5 * p, n + 1, p):
                index, rem = divmod(i, 24)
                bit = remainder_bit[rem]
                if bit and not (xs[index] & bit):
                    xs[index] |= bit
                    count -= 1
        p += 2
        if is_prime(xs, p):
            for i in range(5 * p, n + 1, p):
                index, rem = divmod(i, 24)
                bit = remainder_bit[rem]
                if bit and not (xs[index] & bit):
                    xs[index] |= bit
                    count -= 1
        p += 4

    return count


def init_sieve(n):
    return bytearray((n + 23) // 24)

n = 100000000
xs = init_sieve(n)
print(sieve_of_eratosthenes(xs, n))
# 5761455
print(sum(is_prime(xs, i) for i in range(n+1)))
# 5761455

Edit: the key to understanding how this works is that a sieve creates a repeating pattern. For the primes 2 and 3 the pattern repeats every 2*3 or 6 numbers, and of those 6, 4 have been rendered impossible to be prime leaving only 2. There's nothing limiting you in the choices of prime numbers to generate the pattern, except perhaps for the law of diminishing returns. I decided to try adding 5 to the mix, making the pattern repeat every 2*3*5=30 numbers. Out of these 30 numbers only 8 can be prime, meaning each byte can track 30 numbers instead of the 24 above! That gives you a 20% advantage in memory usage.
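
If you want to check the "2 out of 6" and "8 out of 30" figures yourself, a small sketch like the following lists the residues that survive a given set of wheel primes. The helper name wheel_residues is just something I made up for this illustration.

```python
from math import gcd

def wheel_residues(primes):
    """Return the wheel period and the residues that can still be prime."""
    period = 1
    for p in primes:
        period *= p
    # A residue survives only if it shares no factor with the period.
    residues = [r for r in range(period) if gcd(r, period) == 1]
    return period, residues

for primes in ([2, 3], [2, 3, 5], [2, 3, 5, 7]):
    period, residues = wheel_residues(primes)
    print(f"wheel {primes}: {len(residues)} candidates per {period} numbers")
```

Adding 7 to the wheel leaves 48 candidates per 210 numbers, which no longer maps neatly onto whole bytes; that is the diminishing return mentioned above.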

Here's the updated code. I also simplified it a bit and took out the counting of primes as it went along.

remainder_bit30 = [0,    0x01, 0,    0,    0,    0,    0, 0x02, 0,    0,
                   0,    0x04, 0,    0x08, 0,    0,    0, 0x10, 0,    0x20,
                   0,    0,    0,    0x40, 0,    0,    0, 0,    0,    0x80]

def is_prime(xs, a):
    if a <= 5:
        return (a > 1) and (a != 4)
    index, rem = divmod(a, 30)
    bit = remainder_bit30[rem]
    return (bit != 0) and not (xs[index] & bit)

def sieve_of_eratosthenes(xs):
    n = 30 * len(xs) - 1
    p = 0
    while p*p < n:
        for offset in (1, 7, 11, 13, 17, 19, 23, 29):
            p += offset
            if is_prime(xs, p):
                for i in range(p * p, n + 1, p):
                    index, rem = divmod(i, 30)
                    if index < len(xs):
                        bit = remainder_bit30[rem]
                        xs[index] |= bit
            p -= offset
        p += 30

def init_sieve(n):
    b = bytearray((n + 30) // 30)
    return b



Answer 3:


I think you can try by using a list of booleans to mark whether its index is prime or not:

def sieve_of_erato(range_max):
    # Candidates are 2..range_max; 0 and 1 are not prime.
    primes_count = max(range_max - 1, 0)
    is_prime = [False, False] + [True] * (range_max - 1)
    # Cross out all even numbers greater than 2 first.
    for i in range(4, range_max + 1, 2):
        is_prime[i] = False
        primes_count -= 1
    i = 3
    while i * i <= range_max:
        if is_prime[i]:
            # Update all odd multiples of this prime number, starting at i*i.
            # CAREFUL: Take note of the range args.
            # Reason for a step of 2*i instead of i:
            # Since i and i*i are both odd, (i*i + i) will be even,
            # which means that it has already been marked above.
            for multiple in range(i * i, range_max + 1, i * 2):
                # Only count a number the first time it is crossed out.
                if is_prime[multiple]:
                    is_prime[multiple] = False
                    primes_count -= 1
        i += 2

    return primes_count


def main():
    num_primes = sieve_of_erato(100)
    print(num_primes)


if __name__ == "__main__":
    main()

You can use the is_prime list later on to check whether a number is prime by testing is_prime[number], provided the function also returns the list (as written, it returns only the count).

If this doesn't work, then try segmented sieve.

As a bonus, you might be surprised to know that there is a way to generate the sieve in O(n) rather than O(n log log n). Check the code here.
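
In case the link is unavailable, here is a sketch of the commonly known linear sieve, which may differ from the linked code. It crosses out each composite exactly once, via its smallest prime factor. The function name is mine, and note that it needs a full list of n + 1 integers, so it uses more memory than the approaches above, not less.

```python
def linear_sieve(n):
    """Return the list of primes <= n, visiting each composite exactly once."""
    smallest_factor = [0] * (n + 1)   # 0 means "not crossed out yet"
    primes = []
    for i in range(2, n + 1):
        if smallest_factor[i] == 0:   # nothing divided i, so it is prime
            smallest_factor[i] = i
            primes.append(i)
        for p in primes:
            # Stop once p exceeds the smallest factor of i: beyond that point
            # i * p would be reached again later through a smaller prime.
            if p > smallest_factor[i] or i * p > n:
                break
            smallest_factor[i * p] = p
    return primes
```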




Answer 4:


Here is an example of a segmented sieve approach that should not exceed 8MB of memory.

def primeSieve(n,X,window=10**6): 
    primes     = []       # only store minimum number of primes to shift windows
    primeCount = 0        # count primes beyond the ones stored
    flags      = list(X)  # numbers will be replaced by 0 or 1 as we progress
    base       = 1        # number corresponding to 1st element of sieve
    isPrime    = [False]+[True]*(window-1) # starting sieve
    
    def flagPrimes(): # flag x values for current sieve window
        flags[:] = [isPrime[x-base]*1 if x in range(base,base+window) else x
                    for x in flags]
    for p in (2,*range(3,n+1,2)):       # potential primes: 2 and odd numbers
        if p >= base+window:            # shift sieve window as needed
            flagPrimes()                # set X flags before shifting window
            isPrime = [True]*window     # initialize next sieve window
            base    = p                 # 1st number in window
            for k in primes:            # update sieve using known primes 
                if k>base+window:break
                i = (k-base%k)%k + k*(k==p)  
                isPrime[i::k] = (False for _ in range(i,window,k))
        if not isPrime[p-base]: continue
        primeCount += 1                 # count primes 
        if p*p<=n:primes.append(p)      # store shifting primes, update sieve
        isPrime[p*p-base::p] = (False for _ in range(p*p-base,window,p))

    flagPrimes() # update flags with last window (should cover the rest of them)
    return primeCount,flags     
        

output:

print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229 [0, 1, 1, 0, 0, 1]

print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]

You can play with the window size to get the best trade off between execution time and memory consumption. The execution time (on my laptop) is still rather long for large values of n though:

from timeit import timeit
for w in range(3,9):
    t = timeit(lambda:primeSieve(10**8,[],10**w),number=1)
    print(f"10^{w} window:",t)

10^3 window: 119.463959956
10^4 window: 33.33273301199999
10^5 window: 24.153761258999992
10^6 window: 24.649398391000005
10^7 window: 27.616014667
10^8 window: 27.919413531000004

Strangely enough, window sizes beyond 10^6 give worse performance. The sweet spot seems to be somewhere between 10^5 and 10^6. A window of 10^7 would exceed your 50MB limit anyway.
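
As a possible variation (not the code timed above), the same segmented idea can be sketched with bytearray windows and slice assignment, which keeps each window at one byte per number. It only counts primes and does not answer queries. The function name and the default segment_size are assumptions for illustration, and math.isqrt needs Python 3.8 or later.

```python
from math import isqrt

def count_primes_segmented(n, segment_size=1 << 16):
    """Count the primes <= n while holding only one window in memory."""
    if n < 2:
        return 0
    limit = isqrt(n)
    # Plain sieve for the base primes up to sqrt(n).
    base = bytearray(b"\x01") * (limit + 1)
    base[0:2] = b"\x00\x00"
    for p in range(2, isqrt(limit) + 1):
        if base[p]:
            base[p * p::p] = bytes(len(range(p * p, limit + 1, p)))
    base_primes = [p for p in range(2, limit + 1) if base[p]]

    count = 0
    for low in range(2, n + 1, segment_size):
        high = min(low + segment_size - 1, n)
        window = bytearray(b"\x01") * (high - low + 1)   # window[i] is for low + i
        for p in base_primes:
            if p * p > high:
                break
            # First multiple of p inside the window, but never below p*p.
            start = max(p * p, (low + p - 1) // p * p)
            window[start - low::p] = bytes(len(range(start - low, len(window), p)))
        count += window.count(1)
    return count
```

Peak memory is roughly the base primes plus one window, so the window size can be tuned in the same way as above.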




Answer 5:


I had another idea on how to generate primes quickly in a memory efficient way. It is based on the same concept as the Sieve of Eratosthenes but uses a dictionary to hold the next value that each prime will invalidate (i.e. skip). This only requires storage of one dictionary entry per prime up to the square root of n.

def genPrimes(maxPrime):
    if maxPrime>=2: yield 2           # special processing for 2
    primeSkips = dict()               # skipValue:prime
    for n in range(3,maxPrime+1,2):
        if n not in primeSkips:       # if not in skip list, it is a new prime
            yield n
            if n*n <= maxPrime:       # first skip will be at n^2
                primeSkips[n*n] = n
            continue
        prime = primeSkips.pop(n)     # find next skip for n's prime
        skip  = n+2*prime
        while skip in primeSkips:     # must not already be skipped
            skip += 2*prime                
        if skip<=maxPrime:            # don't skip beyond maxPrime
            primeSkips[skip]=prime           

Using this, the primeSieve function can simply run through the prime numbers, count them, and flag the x values:

def primeSieve(n,X):
    primeCount = 0
    nonPrimes  = set(X)
    for prime in genPrimes(n):
        primeCount += 1
        nonPrimes.discard(prime)
    return primeCount,[0 if x in nonPrimes else 1 for x in X]


print(*primeSieve(9973,[1,2,3,4,9972,9973]))
# 1229 [0, 1, 1, 0, 0, 1]

print(*primeSieve(10**8,[1,2,3,4,9972,9973,1000331]))
# 5761455 [0, 1, 1, 0, 0, 1, 0]

This runs slightly faster than my previous answer and only consumes 78K of memory to generate primes up to 10^8 (in 21 seconds).



Source: https://stackoverflow.com/questions/62899578/making-sieve-of-eratosthenes-more-memory-efficient-in-python
