Generating an optimal binary search tree (Cormen)

Submitted by 我怕爱的太早我们不能终老 on 2021-02-08 13:53:15

Question


I'm reading Cormen et al., Introduction to Algorithms (3rd ed.), section 15.5 on optimal binary search trees, but am having some trouble implementing the pseudocode for the optimal_bst function in Python.

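Here is the example I'm trying to apply the optimal BST to (n = 5 keys; p_i is the probability of a search for key k_i, and q_i the probability of a search ending in dummy key d_i):

    i      0     1     2     3     4     5
    p_i          0.15  0.10  0.05  0.10  0.20
    q_i    0.05  0.10  0.05  0.05  0.05  0.10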

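Let us define e[i,j] as the expected cost of searching an optimal binary search tree containing the keys labeled from i to j. Ultimately, we wish to compute e[1, n], where n is the number of keys (5 in this example). The final recursive formulation is:

$$
e[i,j] =
\begin{cases}
q_{i-1} & \text{if } j = i-1,\\
\min\limits_{i \le r \le j} \bigl\{ e[i,r-1] + e[r+1,j] + w(i,j) \bigr\} & \text{if } i \le j,
\end{cases}
$$

where

$$
w(i,j) = \sum_{l=i}^{j} p_l + \sum_{l=i-1}^{j} q_l .
$$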

which the book implements bottom-up in its OPTIMAL-BST pseudocode.

Notice that the pseudocode mixes 1- and 0-based indexing (its tables are e[1..n+1, 0..n], w[1..n+1, 0..n], and root[1..n, 1..n]), whereas Python uses only 0-based indexing. As a consequence, I'm having trouble implementing it. Here is what I have so far:

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
    for i in range(n-l+1):
        j = i + l
        e[i, j] = np.inf
        w[i, j] = w[i, j-1] + p[j-1] + q[j]
        for r in range(i, j+1):
            t = e[i-1, r-1] + e[r, j] + w[i-1, j]
            if t < e[i-1, j]:
                e[i-1, j] = t
                root[i-1, j] = r

print(w)
print(e)

However, if I run this, the weights w are computed correctly, but the expected search costs e remain 'stuck' at their initialized values:

[[ 0.05  0.3   0.45  0.55  0.7   1.  ]
 [ 0.    0.1   0.25  0.35  0.5   0.8 ]
 [ 0.    0.    0.05  0.15  0.3   0.6 ]
 [ 0.    0.    0.    0.05  0.2   0.5 ]
 [ 0.    0.    0.    0.    0.05  0.35]
 [ 0.    0.    0.    0.    0.    0.1 ]]
[[ 0.05   inf   inf   inf   inf   inf]
 [ 0.    0.1    inf   inf   inf   inf]
 [ 0.    0.    0.05   inf   inf   inf]
 [ 0.    0.    0.    0.05   inf   inf]
 [ 0.    0.    0.    0.    0.05   inf]
 [ 0.    0.    0.    0.    0.    0.1 ]]

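I expect e, w, and root to match the tables in the book's worked example; in particular, the optimal tree should have expected search cost e[1, 5] = 2.75.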

I've been debugging this for a couple of hours now and am still stuck. Can someone point out what is wrong with the Python code above?
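As a reference to check any implementation against, the recurrence for e[i, j] given in the question can also be evaluated top-down with memoization instead of filling tables bottom-up. This minimal sketch keeps the book's 1-based key indices by padding p with an unused entry at position 0; the names `P`, `Q`, `N`, `w`, and `e` are illustrative choices, not part of the original code.

```python
from functools import lru_cache

# Question's data, padded so P[k] is p_k (keys k_1..k_n are 1-based)
# and Q[k] is q_k (dummy keys d_0..d_n are 0-based), as in the book.
P = [0.0, 0.15, 0.10, 0.05, 0.10, 0.20]   # P[0] is unused padding
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
N = len(P) - 1                            # number of real keys (5)


def w(i, j):
    """w(i, j): sum of p_i..p_j plus q_{i-1}..q_j."""
    return sum(P[i:j + 1]) + sum(Q[i - 1:j + 1])


@lru_cache(maxsize=None)
def e(i, j):
    """Expected search cost of an optimal BST on keys k_i..k_j."""
    if j == i - 1:                        # no keys: only dummy d_{i-1}
        return Q[i - 1]
    wij = w(i, j)                         # same for every candidate root r
    return min(e(i, r - 1) + e(r + 1, j) + wij for r in range(i, j + 1))


print(round(e(1, N), 2))                  # 2.75
```

Thanks to `lru_cache`, each e(i, j) is computed only once. Unlike the bottom-up procedure, this version does not record root, so it is only useful for checking the e values.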


Answer 1:


It appears to me that you made a mistake in the indices. I couldn't make it work as expected, but the following code should give you an indication of where I was heading (there is probably an off-by-one error somewhere):

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

def get2(m, i, j):
    return m[i - 1, j - 1]


def set2(m, i, j, v):
    m[i - 1, j - 1] = v


def get1(m, i):
    return m[i - 1]


def set1(m, i, v):
    m[i - 1] = v


e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
    for i in range(n - l + 2):
        j = i + l - 1
        set2(e, i, j, np.inf)
        set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
        for r in range(i, j + 1):
            t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
            if t < get2(e, i, j):
                set2(e, i, j, t)
                set2(root, i, j, r)

print(w)
print(e)

The result:

[[ 0.2   0.4   0.5   0.65  0.9   0.  ]
 [ 0.    0.2   0.3   0.45  0.7   0.  ]
 [ 0.    0.    0.1   0.25  0.5   0.  ]
 [ 0.    0.    0.    0.15  0.4   0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.5   0.7   0.8   0.95  0.    0.3 ]]
[[ 0.2   0.6   0.8   1.2   1.95  0.  ]
 [ 0.    0.2   0.4   0.8   1.35  0.  ]
 [ 0.    0.    0.1   0.35  0.85  0.  ]
 [ 0.    0.    0.    0.15  0.55  0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.7   1.2   1.5   2.    0.    0.3 ]]
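Comparing the helpers above with the table bounds in the pseudocode (e[1..n+1, 0..n], w[1..n+1, 0..n], root[1..n, 1..n]) points to three off-by-one errors: get2/set2 shift the column of e and w although those columns already start at 0; get1(q, j) shifts q although q starts at q_0; and range(n - l + 2) starts i at 0 instead of 1. Below is a sketch with just those three points changed. The helper names get_ew and set_ew are hypothetical, made up for this example; root keeps a shift on both indices because it really is indexed 1..n by 1..n.

```python
import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]        # p_1..p_n live at p[0..n-1]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]  # q_0..q_n live at q[0..n]
n = len(p)


# Hypothetical helpers for e[1..n+1, 0..n] and w[1..n+1, 0..n]:
# only the row is shifted, the column is used as-is.
def get_ew(m, i, j):
    return m[i - 1, j]


def set_ew(m, i, j, v):
    m[i - 1, j] = v


e = np.diag(q)                            # e[i-1, i-1] is e[i, i-1] = q_{i-1}
w = np.diag(q)
root = np.zeros((n, n), dtype=int)        # root[1..n, 1..n]: shift both indices

for l in range(1, n + 1):
    for i in range(1, n - l + 2):         # i starts at 1, as in the pseudocode
        j = i + l - 1
        set_ew(e, i, j, np.inf)
        set_ew(w, i, j, get_ew(w, i, j - 1) + p[j - 1] + q[j])  # q not shifted
        for r in range(i, j + 1):
            t = get_ew(e, i, r - 1) + get_ew(e, r + 1, j) + get_ew(w, i, j)
            if t < get_ew(e, i, j):
                set_ew(e, i, j, t)
                root[i - 1, j - 1] = r

print(round(float(e[0, n]), 2))           # e[1, n] in the book: 2.75
```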



Answer 2:


In the end, I used pandas Series and DataFrame objects initialized with custom index and column labels so that the arrays have the same indexing as in the pseudocode. After that, the pseudocode can be copied almost verbatim:

import numpy as np
import pandas as pd

P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)

p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)

e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))

# DataFrame.get_value()/set_value() were removed in pandas 1.0;
# .at[row, col] is the label-based scalar accessor that replaces them.
for l in range(1, n+1):
    for i in range(1, n-l+2):
        j = i+l-1
        e.at[i, j] = np.inf
        w.at[i, j] = w.at[i, j-1] + p[j] + q[j]
        for r in range(i, j+1):
            t = e.at[i, r-1] + e.at[r+1, j] + w.at[i, j]
            if t < e.at[i, j]:
                e.at[i, j] = t
                root.at[i, j] = r

print(e)
print(w)
print(root)

which yields the expected results:

      0     1     2     3     4     5
1  0.05  0.45  0.90  1.25  1.75  2.75
2  0.00  0.10  0.40  0.70  1.20  2.00
3  0.00  0.00  0.05  0.25  0.60  1.30
4  0.00  0.00  0.00  0.05  0.30  0.90
5  0.00  0.00  0.00  0.00  0.05  0.50
6  0.00  0.00  0.00  0.00  0.00  0.10
      0    1     2     3     4     5
1  0.05  0.3  0.45  0.55  0.70  1.00
2  0.00  0.1  0.25  0.35  0.50  0.80
3  0.00  0.0  0.05  0.15  0.30  0.60
4  0.00  0.0  0.00  0.05  0.20  0.50
5  0.00  0.0  0.00  0.00  0.05  0.35
6  0.00  0.0  0.00  0.00  0.00  0.10
     1    2    3    4    5
1  1.0  1.0  2.0  2.0  2.0
2  0.0  2.0  2.0  2.0  4.0
3  0.0  0.0  3.0  4.0  5.0
4  0.0  0.0  0.0  4.0  5.0
5  0.0  0.0  0.0  0.0  5.0

I would still be interested in a solution with NumPy arrays, though, as that seems more elegant to me.
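For a NumPy-only version, one option is to over-allocate the arrays (one spare row in e and w, and a dummy p_0 in front of p) so that every index can be copied from the pseudocode unchanged, which is what the pandas labels achieve. This sketch does that and also replaces the innermost loop over r with two NumPy slices and np.argmin; the function name optimal_bst is only illustrative. np.argmin returns the first index of the minimum, so ties are broken the same way as by the pseudocode's strict < comparison.

```python
import numpy as np


def optimal_bst(p, q):
    """Bottom-up OPTIMAL-BST that reuses the book's indices unchanged.

    p holds p_1..p_n and q holds q_0..q_n. The tables get one spare row and
    p gets a dummy p_0, so e[i, j] here is exactly e[i, j] in the book.
    """
    n = len(p)
    p = np.concatenate(([0.0], p))            # p[k] is p_k; p[0] is padding
    q = np.asarray(q, dtype=float)
    e = np.zeros((n + 2, n + 1))              # rows 1..n+1 used, columns 0..n
    w = np.zeros((n + 2, n + 1))
    root = np.zeros((n + 1, n + 1), dtype=int)  # rows and columns 1..n used
    for i in range(1, n + 2):
        e[i, i - 1] = w[i, i - 1] = q[i - 1]
    for l in range(1, n + 1):
        for i in range(1, n - l + 2):
            j = i + l - 1
            w[i, j] = w[i, j - 1] + p[j] + q[j]
            # Costs of all candidate roots r = i..j at once:
            # e[i, r-1] is the row slice e[i, i-1:j] and
            # e[r+1, j] is the column slice e[i+1:j+2, j].
            costs = e[i, i - 1:j] + e[i + 1:j + 2, j] + w[i, j]
            k = int(np.argmin(costs))         # first minimum, like strict "<"
            e[i, j] = costs[k]
            root[i, j] = i + k
    return e, w, root


e, w, root = optimal_bst([0.15, 0.10, 0.05, 0.10, 0.20],
                         [0.05, 0.10, 0.05, 0.05, 0.05, 0.10])
print(round(float(e[1, 5]), 2))               # 2.75
print(root[1:, 1:])
```

Row 0 of e and w, and row/column 0 of root, are never read; they exist only so that no index has to be shifted.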




Answer 3:


Kurt, thanks for your post! Yours is the only working implementation of this problem I found. I spent a lot of time wrestling with the indices. Here is my implementation with NumPy arrays.

import numpy as np
import math

def optimalBST(p, q, n):

    # 0-based (n+1) x (n+1) tables: array cell e[i, j] holds the book's
    # e[i+1, j], i.e. the row is shifted by one and the column is not.
    e = np.zeros((n + 1, n + 1))
    w = np.zeros((n + 1, n + 1))
    root = np.zeros((n + 1, n + 1))

    # Initialization: a subtree with no keys holds only dummy key d_i.
    for i in range(n + 1):
        e[i, i] = q[i]
        w[i, i] = q[i]
    for i in range(0, n):
        root[i, i] = i + 1

    for l in range(1, n + 1):
        for i in range(0, n - l + 1):
            j = i + l
            min_ = math.inf
            w[i, j] = w[i, j - 1] + p[j] + q[j]    # p is padded, so p[j] is p_j
            for r in range(i, j):                   # key r+1 is the candidate root
                t = e[i, r] + e[r + 1, j] + w[i, j]
                if t < min_:
                    min_ = t
                    e[i, j] = t
                    root[i, j - 1] = r + 1          # book's root[i+1, j]

    root_pruned = np.delete(np.delete(root, n, 1), n, 0)        # Trim last col & row.

    print("------ e -------")
    print(e)
    print("------ w -------")
    print(w)
    print("----- root -----")
    print(root_pruned)

def main():

    p = [0,.15,.1,.05,.1,.2]
    q = [.05,.1,.05,.05,.05,.1]
    n = len(p)-1

    optimalBST(p,q,n)

if __name__ == '__main__':
    main()

Output:

------ e -------
[[0.05 0.45 0.9  1.25 1.75 2.75]
 [0.   0.1  0.4  0.7  1.2  2.  ]
 [0.   0.   0.05 0.25 0.6  1.3 ]
 [0.   0.   0.   0.05 0.3  0.9 ]
 [0.   0.   0.   0.   0.05 0.5 ]
 [0.   0.   0.   0.   0.   0.1 ]]
------ w -------
[[0.05 0.3  0.45 0.55 0.7  1.  ]
 [0.   0.1  0.25 0.35 0.5  0.8 ]
 [0.   0.   0.05 0.15 0.3  0.6 ]
 [0.   0.   0.   0.05 0.2  0.5 ]
 [0.   0.   0.   0.   0.05 0.35]
 [0.   0.   0.   0.   0.   0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
 [0. 2. 2. 2. 4.]
 [0. 0. 3. 4. 5.]
 [0. 0. 0. 4. 5.]
 [0. 0. 0. 0. 5.]]
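The root table is what lets you recover the optimal tree itself. Here is a small sketch, with illustrative names ROOT, build, show, and cost, that hard-codes the root_pruned output printed above, rebuilds the tree recursively, and recomputes its expected search cost as the sum over all nodes of (depth + 1) times the node's probability, which should agree with e[1, 5]:

```python
# root_pruned from the output above: ROOT[i-1][j-1] is root[i, j],
# the key at the root of the optimal subtree on keys k_i..k_j.
ROOT = [
    [1, 1, 2, 2, 2],
    [0, 2, 2, 2, 4],
    [0, 0, 3, 4, 5],
    [0, 0, 0, 4, 5],
    [0, 0, 0, 0, 5],
]
P = [0.0, 0.15, 0.10, 0.05, 0.10, 0.20]   # P[k] is p_k (P[0] unused)
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]  # Q[k] is q_k


def build(i, j):
    """Optimal subtree on keys k_i..k_j as (key, left, right).

    An empty range (j == i - 1) is the dummy leaf d_j, written ("d", j).
    """
    if j == i - 1:
        return ("d", j)
    k = ROOT[i - 1][j - 1]
    return (k, build(i, k - 1), build(k + 1, j))


def show(node, depth=0):
    """Print the tree in preorder, indenting two spaces per level."""
    if node[0] == "d":
        print("  " * depth + f"d{node[1]}")
        return
    key, left, right = node
    print("  " * depth + f"k{key}")
    show(left, depth + 1)
    show(right, depth + 1)


def cost(node, depth=0):
    """Expected search cost: sum of (depth + 1) * probability over nodes."""
    if node[0] == "d":
        return (depth + 1) * Q[node[1]]
    key, left, right = node
    return (depth + 1) * P[key] + cost(left, depth + 1) + cost(right, depth + 1)


tree = build(1, 5)
show(tree)
print(round(cost(tree), 2))               # 2.75
```

For this table the tree has k2 at the root, k1 (with leaves d0 and d1) as its left child, and k5 as its right child; k4 is the left child of k5, and k3 the left child of k4.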


Source: https://stackoverflow.com/questions/46160969/generating-an-optimal-binary-search-tree-cormen
