Optimizing mutable array state heavy manipulation code

末鹿安然 submitted on 2019-12-06 10:49:56

One optimization you can do to mutable arrays is not to use them at all. In particular, the problem you have linked to has a right fold solution.

The idea is to fold over the list, greedily swapping each item with the largest value to its right, and to keep track of the swaps already made in a Data.Map:

import qualified Data.Map as M
import Data.Map (empty, insert)

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = foldr go (\_ _ _ -> []) xs n empty k
    where
    go x run i m k
        -- out of budget to do a swap or no swap necessary
        | k == 0 || y == i = y : run (pred i) m k
        -- make a swap and record the swap made in the map
        | otherwise        = i : run (pred i) (insert i y m) (k - 1)
        where
        -- find the value current position is swapped with
        y = find x
        find v = case M.lookup v m of
            Just a  -> find a
            Nothing -> v

Above, run is a function which, given the reverse index i, the current mapping m and the remaining swap budget k, solves the rest of the list. By reverse index I mean the indices of the list counted in the reverse direction: n, n - 1, ..., 1.

The folding function go builds the run function at each step by updating the values of i, m and k that are passed on to the next step. At the end we call this function with the initial parameters i = n, m = empty and the initial swap budget k.

The recursive search in find can be eliminated by maintaining a reverse map, but this already performs much faster than the Java code you have posted.
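For readers less used to folds that build functions, the state threading done by go and run can be sketched as an explicit loop. The Python below is an illustrative transliteration, not the answer's code: solve_with_map and swaps are names invented here, and the loop variables i, swaps and k play the roles of i, m and k above.

```python
def solve_with_map(n, k, xs):
    """Greedy pass that records swaps in a dict instead of mutating xs."""
    swaps = {}  # swaps[i] = value now sitting where i originally was
    out = []
    i = n       # reverse index: the value we would like at this position
    for x in xs:
        # find the value currently at this position by following recorded swaps
        y = x
        while y in swaps:
            y = swaps[y]
        if k == 0 or y == i:
            # out of budget, or no swap necessary
            out.append(y)
        else:
            # place i here and remember that y moved to i's old position
            out.append(i)
            swaps[i] = y
            k -= 1
        i -= 1
    return out
```

With n = 5, k = 1 and xs = [4, 2, 3, 5, 1], the single swap brings 5 to the front and the 4 it displaced shows up where 5 used to be.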


Edit: The above solution still pays a logarithmic cost for tree access. Here is an alternative solution using a mutable STUArray and the monadic fold foldM_, which in fact performs faster than the above:

import Control.Monad.ST (ST)
import Control.Monad (foldM_)
import Data.Array.Unboxed (UArray, elems, listArray, array)
import Data.Array.ST (STUArray, readArray, writeArray, runSTUArray, thaw)

-- the first three arguments (arr, rev, n) are fixed by partial application in solve
swap :: STUArray s Int Int -> STUArray s Int Int -> Int
     -> Int -> Int -> ST s Int
swap   _   _ _ 0 _ = return 0  -- out of budget to make a swap
swap arr rev n k i = do
    xi <- readArray arr i
    if xi + i == n + 1
    then return k -- no swap necessary
    else do -- make a swap, and reduce budget
        j <- readArray rev (n + 1 - i)
        writeArray rev xi j
        writeArray arr j  xi
        writeArray arr i (n + 1 - i)
        return $ pred k

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = elems $ runSTUArray $ do
    arr <- thaw (listArray (1, n) xs :: UArray Int Int)
    rev <- thaw (array (1, n) (zip xs [1..]) :: UArray Int Int)
    foldM_ (swap arr rev n) k [1..n]
    return arr
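To see why the rev array removes the need for the find chase, here is the same greedy pass written as a plain loop in Python. This is only a sketch, assuming xs is a permutation of 1..n; solve_with_index and pos are hypothetical names, with pos playing the role of rev.

```python
def solve_with_index(n, k, xs):
    """Greedy pass with an array plus a value -> index lookup table."""
    arr = list(xs)
    pos = [0] * (n + 1)          # pos[v] = current index of value v in arr
    for idx, v in enumerate(arr):
        pos[v] = idx
    for i in range(n):
        if k == 0:
            break                # out of budget to make a swap
        want = n - i             # value that belongs at index i
        if arr[i] == want:
            continue             # no swap necessary
        j = pos[want]            # where the wanted value currently sits
        pos[arr[i]] = j          # the displaced value moves to index j
        pos[want] = i            # optional: want is never looked up again
        arr[i], arr[j] = arr[j], arr[i]
        k -= 1
    return arr
```

Each position costs a constant number of array reads and writes, which is the saving over the tree lookups of the map-based version.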

Not exactly an answer to #2, but there is a left fold solution that requires loading at most ~K values in memory at a time.

Because the problem deals with permutations, we know that 1 through N will appear in the output. If K > 0, at least the first K terms are going to be N, N-1, ..., N-K+1, because we can afford K swaps. In addition, we expect some (K/N) elements to already be in their optimal position.

This suggests an algorithm:

Initialize a map / dictionary and scan the input xs as zip xs [n, n-1..]. For every (x, i), if x /= i, we decrement K and update our dictionary so that dct[i] = x. This procedure terminates when K == 0 (out of swaps) or when we run out of input (in which case we can output {N, N-1, ..., 1}).

Next, if we have any more x <- xs we look at each one and print x if x is not in our dictionary or dct[x] otherwise.

The above algorithm can fail to produce an optimal permutation only if our dictionary contains a cycle. In that case we counted |cycle| swaps for the elements on that cycle, but one of those swaps was unnecessary, because the element it targets had already been put in place by the earlier swaps. So we can always save a swap on every cycle (i.e. increment K).
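A small example makes the cycle case concrete. The sketch below runs only the scan phase described above and checks the resulting dictionary for a cycle; scan and has_cycle are hypothetical helper names, and the cycle check is deliberately naive.

```python
def scan(n, k, xs):
    """Scan phase: pair each x with its target n, n-1, ... and record mismatches."""
    dct = {}
    for x in xs:
        if k == 0:
            break            # out of swaps
        if x != n:
            dct[n] = x       # target n is output where x used to be
            k -= 1
        n -= 1
    return dct


def has_cycle(dct):
    """Return True if following dct from some key leads back into a loop."""
    for start in dct:
        v = dct[start]
        seen = set()
        while v in dct and v not in seen:
            seen.add(v)
            v = dct[v]
        if v in seen:
            return True
    return False
```

For the input 2 3 1 with K = 2 the scan records {3: 2, 2: 3}, which is a cycle, although a single swap already yields 3 2 1, so one swap can be refunded.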

Finally, this gives the memory-efficient algorithm.

Step 0: get N, K

Step 1: Read the input permutation and output {N, N-1, ..., N-K-E+1}, where E = number of elements X equal to N - (index of X); then set N <- N - K - E, K <- 0, and update dict as per above.

Step 2: Remove and count cycles from dict; let cycles = number of cycles. If cycles > 0, let K <- cycles and go to step 1; else go to step 3. We can make this step more efficient by optimizing the dict.

Step 3: Output the rest of the input as is.

The following Python code implements the idea and can be made quite fast if better cycle detection is used. Of course, the data should be read in chunks, unlike below.

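from collections import deque

# Python 3; the first input line is "N K", the second is the permutation
n, t = map(int, input().split())

xs = deque(map(int, input().split()))

dct = {}  # dct[target] = value that was sitting where target is output

cycles = True
while cycles:
    # Step 1: output n, n-1, ... while swaps remain
    while t > 0 and xs:
        x = xs.popleft()
        if x != n:
            dct[n] = x
            t -= 1
        print(n, end=" ")
        n -= 1

    # Step 2: resolve chains in dct and refund one swap per cycle
    cycles = False
    for k, v in list(dct.items()):  # copy, because dct is modified below
        visited = set()
        cycle = False
        while v in dct:
            if v in visited:
                cycle = True
                break
            visited.add(v)
            v, buf = dct[v], v
            dct[buf] = v
        if cycle:
            cycles = True
            for i in visited:
                del dct[i]
            t += 1
        else:
            dct[k] = v

# Step 3: output the rest, substituting displaced values
while xs:
    x = xs.popleft()
    print(dct.get(x, x), end=" ")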