Question
I have this problem where I need to find the number of sums of powers that are equal to a number. So for example:
An input of
100 2 would yield an output of 3, because 100 = 10^2 = 6^2 + 8^2 = 1^2 + 3^2 + 4^2 + 5^2 + 7^2, and an input of 100 3 would yield an output of 1, because 100 = 1^3 + 2^3 + 3^3 + 4^3.
So my function for solving this problem is:
findNums :: Int -> Int -> Int
findNums a b = length [xs | xs <- (drop 1 .) subsequences [pow x b | x <- [1..c]], foldr (+) (head xs) (tail xs) == a] where c = root a b 0
root :: Int -> Int -> Int -> Int
root n a i
    | pow i a <= n && pow (i+1) a > n = i
    | otherwise = root n a (i+1)
pow :: Int -> Int -> Int
pow _ 0 = 1
pow n a = n * pow n (a - 1)
I find all the possible values that are able to be in my set of numbers that will add up to the desired number. Then I find all possible sublists inside that list and see how many of those add up to the desired number. This works but since it is an exhaustive search it takes a long time for inputs like 800 2. Is it possible to optimize the sequences such that only the "plausible" subsequences are returned? Or is it better to look at parallel computation for this sort of problem?
Answer 1:
Let's take a tour through a few things.
Benchmarking
First up: let's make sure we're actually making improvements as we go! For that, we'll need some benchmarks. The criterion package is perfect for this. We'll also make sure to compile with optimizations (so -O2 on all calls to GHC). Here's how simple setting up a benchmark can be:
import Criterion.Main
-- your code goes here
main = defaultMain
[ bench "findNums 100 2" (nf (uncurry findNums) (100, 2))
, bench "findNums 800 2" (nf (uncurry findNums) (800, 2))
]
One could also implement the benchmark as nf (findNums 100) 2, but I chose this way so that we can't "cheat" by precomputing a lookup table for 100, which would push all the work into the benchmark setup rather than into the part that is actually timed. Here's the result for the original implementation:
benchmarking 100 2
time 762.7 ns (757.4 ns .. 768.5 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 762.5 ns (760.4 ns .. 765.3 ns)
std dev 7.706 ns (6.378 ns .. 10.59 ns)
benchmarking 800 2
time 29.17 s (28.28 s .. 29.87 s)
1.000 R² (1.000 R² .. 1.000 R²)
mean 29.26 s (29.08 s .. 29.35 s)
std dev 159.2 ms (0.0 s .. 165.2 ms)
variance introduced by outliers: 19% (moderately inflated)
Use libraries
Now, the low-hanging fruit is to use existing implementations of things and hope their authors did something better than us. To that end, we'll use the standard function (^) instead of pow, and integerRoot from the arithmoi package instead of root. Additionally, we'll swap out the lazy foldr for a strict foldl'. For my own sanity, I also reformatted the very long line into short ones. The full result now looks like this:
import Criterion.Main
import Data.List
import Math.NumberTheory.Powers
sum' :: Num a => [a] -> a
sum' = foldl' (+) 0
findNums :: Int -> Int -> Int
findNums a b = length
    [ xs
    | xs <- drop 1 . subsequences $ [x ^ b | x <- [1..c]]
    , sum' xs == a
    ] where c = integerRoot b a
main = defaultMain
[ bench "100 2" (nf (uncurry findNums) (100, 2))
, bench "800 2" (nf (uncurry findNums) (800, 2))
]
Benchmark results look like this now:
benchmarking 100 2
time 722.8 ns (721.3 ns .. 724.3 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 722.6 ns (721.4 ns .. 724.1 ns)
std dev 4.440 ns (3.738 ns .. 5.674 ns)
benchmarking 800 2
time 17.16 s (16.93 s .. 17.64 s)
1.000 R² (1.000 R² .. 1.000 R²)
mean 17.05 s (16.99 s .. 17.15 s)
std dev 88.10 ms (0.0 s .. 94.58 ms)
A little under twice as fast with very little effort. Nice!
Better algorithm
A significant problem with subsequences is that, even if we compute that sum' [x,y,z] > a, we still look at all the longer subsequences that start with [x,y,z]. Given the structure of subsequences' return type, there's not much we can do about that; so let's design an implementation that gives us a bit more structure. We'll build a tree, where paths from the root to any internal node give us a subsequence.
import Data.Tree
subsequences :: [a] -> Forest a
subsequences [] = []
subsequences (x:xs) = Node x rest : rest where
    rest = subsequences xs
(Just for fun, this produces exponentially large semantic trees with very small space usage -- roughly as much space as the original list -- due to aggressive subtree sharing.) What's cool about this representation is that if we break off the search at a node, we cut off huge swaths of uninteresting results below it. This can be realized by implementing an analog of lists' takeWhile for these trees:
takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
    goForest m forest = forest >>= goTree m
    goTree m (Node m' children) = let m'' = m <> m' in
        [Node m' (goForest m'' children) | predicate m'']
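As a small side illustration (assuming Sum and getSum from Data.Monoid are also imported; drawForest comes with Data.Tree), something like this prints the subsequence tree of [1,2,3] pruned at a running sum of 3, leaving only the paths for [1], [1,2], [2], and [3]:
-- needs: import Data.Monoid (Sum (..), getSum)
-- Show the pruned subsequence tree of [1,2,3]: only paths whose running
-- sum stays <= 3 survive the pruning.
demo :: IO ()
demo = putStr
     . drawForest
     . map (fmap (show . getSum))
     $ takeWhileTree (<= Sum 3) (subsequences (map Sum [1, 2, 3]))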
Let's give it a try. Complete code is now:
import Criterion.Main
import Data.Foldable
import Data.Monoid
import Data.Tree
import Math.NumberTheory.Powers
subsequencesTree :: [a] -> Forest a
subsequencesTree [] = []
subsequencesTree (x:xs) = Node x rest : rest where
    rest = subsequencesTree xs

takeWhileTree :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTree predicate = goForest mempty where
    goForest m forest = forest >>= goTree m
    goTree m (Node m' children) = let m'' = m <> m' in
        [Node m' (goForest m'' children) | predicate m'']

leaves :: Forest a -> [[a]]
leaves [] = [[]]
leaves forest = do
    Node x children <- forest
    xs <- leaves children
    return (x:xs)

findNums :: Int -> Int -> Int
findNums a b = length
    [ xs
    | xs <- leaves
          . takeWhileTree (<= Sum a)
          . subsequencesTree
          $ [Sum (x ^ b) | x <- [1..c]]
    , fold xs == Sum a
    ] where c = integerRoot b a

main = defaultMain
    [ bench "100 2" (nf (uncurry findNums) (100, 2))
    , bench "800 2" (nf (uncurry findNums) (800, 2))
    ]
This looks like a lot of work, but from the timings, it really pays off:
benchmarking 100 2
time 16.67 μs (16.53 μs .. 16.77 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 16.60 μs (16.52 μs .. 16.72 μs)
std dev 325.4 ns (270.5 ns .. 444.1 ns)
variance introduced by outliers: 17% (moderately inflated)
benchmarking 800 2
time 22.59 ms (22.26 ms .. 22.89 ms)
0.999 R² (0.999 R² .. 1.000 R²)
mean 22.44 ms (22.34 ms .. 22.57 ms)
std dev 260.3 μs (191.6 μs .. 332.2 μs)
That's a speedup factor of roughly 750 on findNums 800 2.
Parallelization
I had a go at parallelizing this by using concat and parMap in takeWhileTree instead of (>>=), so that separate branches of the tree would be explored in parallel. In every case the overhead of parallelizing far outweighed the benefit of having several threads. Good thing we put that benchmark in place at the beginning!
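For reference, such a variant might look roughly like this (a sketch of the idea rather than the exact code that was benchmarked):
import Control.Parallel.Strategies (parMap, rseq)
import Data.Tree (Forest, Tree (..))

-- Same pruning as takeWhileTree, but sibling subtrees are evaluated in
-- parallel via parMap instead of being traversed sequentially with (>>=).
takeWhileTreePar :: Monoid m => (m -> Bool) -> Forest m -> Forest m
takeWhileTreePar predicate = goForest mempty where
    goForest m forest = concat (parMap rseq (goTree m) forest)
    goTree m (Node m' children) = let m'' = m <> m' in
        [Node m' (goForest m'' children) | predicate m'']
(Remember that this only has a chance of helping when compiled with -threaded and run with +RTS -N.)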
Answer 2:
As you suggest, there is some room for optimization here without resorting to parallelizing things (which, keep in mind, can at best give a 4x speedup if you're going from one to four parallel threads).
What the subsequences function is doing is essentially going through the list, and for each element it creates two execution branches: one where that element is included, and one where it isn't. I.e., subsequences [1,2,3] does:
                     start
             /-------/   \-------\          (take 1 or not)
          [1,..]               [..]
         /      \             /    \        (take 2 or not)
   [1,2,..]     [1,..]    [2,..]   [..]
     /  \        /  \      /  \     / \     (take 3 or not)
[1,2,3] [1,2] [1,3] [1] [2,3] [2] [3] []
The result of subsequences [1,2,3] is then a list containing the leaf nodes at the bottom.
Now, at each of the intermediate nodes (i.e. [1,2,..]), we can check the result of applying the value function (i.e., the sum of squares or cubes or etc.) to the numbers already taken. If we're already above the target, there's no point in continuing that branch. If we write this subsequence generation logic by ourselves, we can do that:
findNums :: Int -> Int -> Int
findNums a b = findNums' a b 1 0
findNums' :: Int -> Int -> Int -> Int -> Int
findNums' a b c s
    | s + c^b > a  = 0
    | s + c^b == a = 1
    | otherwise    = findNums' a b (c+1) (s + c^b) +
                     findNums' a b (c+1) s
Here c is our counter and s is the sum of the powers of the numbers we have picked. There are three cases in findNums':
In the first case, we check whether including this number would make us go above target. In that case, this branch is not going to give any valid results, so we terminate it and indicate that it contains no solution by returning 0.
In the second case, we check whether including this number would put us right on spot. In that case we return 1, essentially noting that we have found a solution.
If none of these are true, we recurse further with two different branches: one where we add c^b to our sum, and one where we refrain from doing so. We add the results together, which means that the result here will be the number of branches below this point that have found a valid solution.
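A quick sanity check against the examples from the question (assuming the definitions above are in scope):
main :: IO ()
main = do
    print (findNums 100 2)  -- 3, per the question: 10^2, 6^2+8^2, 1^2+3^2+4^2+5^2+7^2
    print (findNums 100 3)  -- 1, per the question: 1^3+2^3+3^3+4^3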
Answer 3:
In this case it is useful to write a function which returns the actual sequences because that function can be written recursively in terms of itself.
To simplify things, let's just consider sums of squares. Also, we will first consider ordered sequences (with repeated values allowed); later we will look at how to modify the algorithm to produce only unordered sequences without any repeated numbers.
Here is our first attempt. The algorithm is based on this idea:
Idea 1:
To obtain a sequence whose sum of squares is n, first pick a value c and a sequence xs whose sum of squares is n-c*c and put the two together.
-- an integer sqrt function
isqrt n = floor $ (sqrt (fromIntegral n) :: Double)
pows2a :: Int -> [ [Int] ]
pows2a n
    | n < 0     = []
    | n == 0    = [ [] ]
    | otherwise = [ (c:xs) | c <- [start,start-1..1], xs <- pows2a (n-c*c) ]
    where start = isqrt n
This works, but returns permutations of solutions as well as solutions with repeated elements; e.g., pows2a 6 returns [2,1,1], [1,2,1], [1,1,2] and [1,1,1,1,1,1].
To only return unordered sequences (without repetition) we use this idea:
Idea 2:
To obtain a sequence whose sum of squares is n, first pick a value c and a sequence xs whose sum of squares is n-c*c and all of whose elements are < c and put the two together.
This is just a slight modification of our first algorithm:
pows2b :: Int -> [[Int]]
pows2b n
    | n < 0     = []
    | n == 0    = [ [] ]
    | otherwise = [ (c:xs) | c <- [start, start-1..1], xs <- pows2b (n-c*c), all (< c) xs ]
    where
      start = isqrt n
This works but a call like pows2b 100 takes a long time to complete because we are making calls to pows2b with the same argument multiple times.
We can solve that problem by memoizing the results, and this is what pows2c does:
powslist = map pows2c [0..]
pows2c n
    | n == 0    = [ [] ]
    | otherwise = [ (c:xs) | c <- [s,s-1..1], xs <- powslist !! (n-c*c), all (< c) xs ]
    where s = isqrt n
Here the recursive call with argument n-c*c is replaced by a lookup into the lazily built list powslist, which caches the answers.
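As a quick check (assuming the definitions above are in scope), the question's 100 2 example can be reproduced like so; the three decompositions 10^2, 6^2 + 8^2 and 1^2 + 3^2 + 4^2 + 5^2 + 7^2 should all show up:
main :: IO ()
main = do
    mapM_ print (pows2c 100)     -- the solutions themselves, as descending lists
    print (length (pows2c 100))  -- 3, matching the question's expected output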
Source: https://stackoverflow.com/questions/27892435/optimization-possible-or-use-parallel-computing