How to maximize the weight of a tree by removing subtrees [closed]

只谈情不闲聊 submitted on 2019-11-29 09:02:09

There's a straightforward recursive algorithm for this. The most profitable pruning you can perform on a tree is either to keep the root and apply the most profitable pruning to each of its direct subtrees, or to prune away the whole tree. This translates directly to code.

Assuming the tree has been processed into a data structure where every node has a value attribute representing the node's value and a children attribute storing a list of the node's child nodes, and that X is the cost of removing one subtree, the following Python function computes the maximum profit:

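def max_profit(node, X):
    # Either prune the whole subtree here (paying the removal cost X once),
    # or keep this node and prune each child subtree as profitably as possible.
    return max(
        -X,
        node.value + sum(max_profit(c, X) for c in node.children))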

with the two options in the max call representing the choice to either prune the whole tree away at the root, or to keep the root and process the subtrees.
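To try the recursive function, here is a minimal self-contained sketch. The `Node` dataclass is a hypothetical stand-in for the processed tree described above, and X is passed in as a parameter rather than read from a global.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    # Hypothetical node type matching the text: a value and a list of children.
    value: int
    children: List["Node"] = field(default_factory=list)


def max_profit(node: Node, X: int) -> int:
    """Best achievable total for node's subtree when each removal costs X."""
    return max(
        -X,  # prune the whole subtree here and pay the removal cost once
        node.value + sum(max_profit(c, X) for c in node.children),  # keep node
    )


# Root worth 5 with a -4 leaf and a 3 leaf; each removal costs 2.
tree = Node(5, [Node(-4), Node(3)])
print(max_profit(tree, 2))  # 6: remove the -4 leaf, keep everything else
```

Each leaf independently chooses between keeping its value and paying X to disappear, and the root then adds its own value to the best results of its children.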

An alternative idea is to traverse the tree and work out, for every node, how much the profit would increase if that node's subtree were removed. Do this analysis for every node before removing anything, then remove the subtrees that increase the profit the most. This takes two passes:

1) Using a depth-first traversal that handles children before their parent (leaves first, then working back up towards the root), compute the profit gain of each node as G(i) = -A(i) + G(j) + G(k) + ..., where i is the current node, A(i) is its value, and j, k, ... are its children. In other words, G(i) is the change in total value if node i and its whole subtree are removed, before paying the removal cost X.

During the same traversal, also compute the maximum profit gain M(i) over node i and all of its descendants. This tells us whether it is more profitable to remove the node itself or a subtree further down. We compute it as M(i) = max(G(i), M(j), M(k), ...), with i, j, k, ... defined as above. Missing children are simply left out of the max; for a leaf, M(i) = G(i).

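2) Traversing the tree top-down (e.g. breadth-first), remove node i together with its subtree if G(i) == M(i) and G(i) >= X, i.e. if removing this whole subtree is at least as good as removing any single subtree inside it, and the gain covers the cost X.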
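A small worked example of the two formulas may help. The tree is hypothetical: a root r with value 5 and two leaf children, j with value -4 and k with value 3, with removal cost X = 2.

```latex
\begin{aligned}
G(j) &= -A(j) = 4, & M(j) &= G(j) = 4\\
G(k) &= -A(k) = -3, & M(k) &= G(k) = -3\\
G(r) &= -A(r) + G(j) + G(k) = -5 + 4 - 3 = -6, & M(r) &= \max(G(r), M(j), M(k)) = \max(-6, 4, -3) = 4
\end{aligned}
```

In the second pass r is kept because G(r) ≠ M(r), j is removed because G(j) = M(j) = 4 ≥ X, and k is kept because G(k) < X. The total value goes from 5 - 4 + 3 = 4 to (5 + 3) - X = 6.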

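def compute_gain(node):
    # Children first (post-order); a plain loop, since map() is lazy in Python 3.
    for c in node.children:
        compute_gain(c)
    # G(i): change in total value if this node and its subtree are removed.
    node.gain = -node.value + sum(c.gain for c in node.children)
    # M(i): best single-subtree gain here; a leaf has no children, so M(i) = G(i).
    node.max_gain = max([node.gain] + [c.max_gain for c in node.children])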

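def prune_tree(node, X, removed):
    # Remove this subtree if that beats any removal inside it and covers X.
    if node.gain >= X and node.max_gain == node.gain:
        removed.append(node)  # len(removed) counts the removals
        return False          # tell the parent to drop this child
    node.children = [c for c in node.children if prune_tree(c, X, removed)]
    return True               # keep this node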
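Putting the two passes together, here is a self-contained sketch. The `Node` dataclass, `total`, `two_pass_profit`, `example_a` and `example_b` are hypothetical helpers added for illustration; `max_profit` is the recursive function from the first part, repeated so the block runs on its own. Removed subtrees are not descended into, so each removal is paid for exactly once.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    # Hypothetical node type; gain and max_gain are filled in by compute_gain.
    value: int
    children: List["Node"] = field(default_factory=list)
    gain: int = 0
    max_gain: int = 0


def compute_gain(node: Node) -> None:
    # Pass 1, children before parent: G(i) and M(i) as described above.
    for c in node.children:
        compute_gain(c)
    node.gain = -node.value + sum(c.gain for c in node.children)
    node.max_gain = max([node.gain] + [c.max_gain for c in node.children])


def prune_tree(node: Node, X: int, removed: List[Node]) -> bool:
    # Pass 2, top-down: returning False tells the parent to drop this subtree.
    if node.gain >= X and node.max_gain == node.gain:
        removed.append(node)
        return False
    node.children = [c for c in node.children if prune_tree(c, X, removed)]
    return True


def total(node: Node) -> int:
    # Sum of the values still present in this subtree.
    return node.value + sum(total(c) for c in node.children)


def two_pass_profit(root: Node, X: int) -> int:
    # Remaining value minus X for every removed subtree.
    compute_gain(root)
    removed: List[Node] = []
    remaining = total(root) if prune_tree(root, X, removed) else 0
    return remaining - X * len(removed)


def max_profit(node: Node, X: int) -> int:
    # The recursive solution, for comparison (does not modify the tree).
    return max(-X, node.value + sum(max_profit(c, X) for c in node.children))


def example_a() -> Node:
    # Root worth 5 with leaves worth -4 and 3.
    return Node(5, [Node(-4), Node(3)])


def example_b() -> Node:
    # Root worth 100 above a node worth 8 that has two leaves worth -10.
    return Node(100, [Node(8, [Node(-10), Node(-10)])])


print(two_pass_profit(example_a(), 2), max_profit(example_a(), 2))  # 6 6
print(two_pass_profit(example_b(), 5), max_profit(example_b(), 5))  # 95 98
```

On the second example the rule removes the node worth 8 as a whole (G = M = 12, so the net change is 12 - 5 = 7), whereas removing its two -10 children separately changes the total by 20 - 2 * 5 = 10. So, at least as sketched here, the two-pass rule can fall short of the recursive max_profit, which weighs both options at every node.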