Algorithm to print all paths with a given sum in a binary tree

Submitted by 折月煮酒 on 2019-11-30 11:00:58

Question


The following is an interview question.

You are given a binary tree (not necessarily BST) in which each node contains a value. Design an algorithm to print all paths which sum up to that value. Note that it can be any path in the tree - it does not have to start at the root.

Although I am able to find all paths in the tree that start at the root and have the given sum, I am not able to do so for paths not starting at the root.


Answer 1:


Well, this is a tree, not a graph. So, you can do something like this:

Pseudocode:

global ResultList

function ProcessNode(CurrentNode, CurrentSum)
    CurrentSum+=CurrentNode->Value
    if (CurrentSum==SumYouAreLookingFor) AddNodeTo ResultList
    for all Children of CurrentNode
          ProcessNode(Child,CurrentSum)

Well, this gives you the paths that start at the root. However, you can just make a tiny change:

    for all Children of CurrentNode
          ProcessNode(Child,CurrentSum)
          ProcessNode(Child,0)

You might need to think about it for a second (I'm busy with other things), but this should basically run the same algorithm rooted at every node in the tree.
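A minimal Java sketch of "run the same algorithm rooted at every node", assuming a hypothetical Node class with value/left/right fields. Calling ProcessNode(Child, 0) from inside every ProcessNode call restarts nodes two or more levels below the root once per call to their parent, so the same path can be reported more than once. This sketch instead uses a separate outer traversal (processTree, a hypothetical name) that starts the root-anchored search exactly once per node, and it keeps the current path so whole paths, not just end nodes, are recorded.

```java
import java.util.ArrayList;
import java.util.List;

public class AnyStartPathSum {
    // Hypothetical node type for this sketch.
    static class Node {
        int value;
        Node left, right;

        Node(int value) {
            this(value, null, null);
        }

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Outer traversal: start the root-anchored search exactly once at every node.
    static void processTree(Node node, int target, List<List<Integer>> results) {
        if (node == null) {
            return;
        }
        processNode(node, 0, target, new ArrayList<>(), results);
        processTree(node.left, target, results);
        processTree(node.right, target, results);
    }

    // ProcessNode from the pseudocode, extended to remember the values on the current path.
    static void processNode(Node node, int currentSum, int target, List<Integer> path,
                            List<List<Integer>> results) {
        if (node == null) {
            return;
        }
        currentSum += node.value;
        path.add(node.value);
        if (currentSum == target) {
            results.add(new ArrayList<>(path)); // copy, because path is modified while backtracking
        }
        processNode(node.left, currentSum, target, path, results);
        processNode(node.right, currentSum, target, path, results);
        path.remove(path.size() - 1); // backtrack before returning to the parent
    }

    public static void main(String[] args) {
        // Hypothetical tree: 1 with children 2 and 3; the 2 has a left child 3.
        Node root = new Node(1, new Node(2, new Node(3), null), new Node(3));
        List<List<Integer>> results = new ArrayList<>();
        processTree(root, 3, results);
        System.out.println(results); // prints [[1, 2], [3], [3]]
    }
}
```

The two searches are kept apart on purpose: processNode only ever extends a path downward, and processTree is the only place a new path is started.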

EDIT: this actually gives the "end node" only. However, as this is a tree, you can just start at those end nodes and walk back up until you get the required sum.
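A sketch of the "walk back up" step, assuming each node has a parent pointer (the hypothetical Node class below adds one; the pseudocode above does not have it). Unlike the wording above, it keeps walking all the way to the root instead of stopping at the first match, because with zero or negative values more than one start node can produce the required sum.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class WalkBackUp {
    // Hypothetical node type with a parent pointer, which walking back up requires.
    static class Node {
        int value;
        Node left, right, parent;

        Node(int value) {
            this.value = value;
        }

        // Attach child on the left and set its parent pointer; returns the child for chaining.
        Node addLeft(Node child) {
            left = child;
            child.parent = this;
            return child;
        }
    }

    // Walk from endNode up to the root, recording every path that ends at endNode and sums to target.
    static List<List<Integer>> pathsEndingAt(Node endNode, int target) {
        List<List<Integer>> results = new ArrayList<>();
        List<Integer> upward = new ArrayList<>(); // values from endNode up to the current ancestor
        int sum = 0;
        for (Node n = endNode; n != null; n = n.parent) {
            sum += n.value;
            upward.add(n.value);
            if (sum == target) {
                List<Integer> path = new ArrayList<>(upward);
                Collections.reverse(path); // report the path top-down
                results.add(path);
            }
        }
        return results;
    }

    public static void main(String[] args) {
        // Hypothetical chain 1 -> 2 -> -2 -> 2, all attached as left children.
        Node root = new Node(1);
        Node end = root.addLeft(new Node(2)).addLeft(new Node(-2)).addLeft(new Node(2));
        System.out.println(pathsEndingAt(end, 2)); // prints [[2], [2, -2, 2]]
    }
}
```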

EDIT 2: and, of course, if all values are positive then you can abort the descent if your current sum is >= the required one




Answer 2:


Here's an O(n + numResults) answer (essentially the same as @Somebody's answer, but with all issues resolved):

  1. Do a pre-order, in-order, or post-order traversal of the binary tree.
  2. As you do the traversal, maintain the cumulative sum of node values from the root node to the node above the current node. Let's call this value cumulativeSumBeforeNode.
  3. When you visit a node in the traversal, add it to a hashtable at key cumulativeSumBeforeNode (the value at that key will be a list of nodes).
  4. Compute cumulativeSumIncludingNode (cumulativeSumBeforeNode plus the current node's value) minus the target sum. Look up this difference in the hash table.
  5. If the hash table lookup succeeds, it should produce a list of nodes. Each one of those nodes represents the start node of a solution. The current node represents the end node for each corresponding start node. Add each [start node, end node] combination to your list of answers. If the hash table lookup fails, do nothing.
  6. When you've finished visiting a node in the traversal, remove the node from the list stored at key cumulativeSumBeforeNode in the hash table.

Code:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class BinaryTreePathsWithSum {
    public static void main(String[] args) {
        BinaryTreeNode a = new BinaryTreeNode(5);
        BinaryTreeNode b = new BinaryTreeNode(16);
        BinaryTreeNode c = new BinaryTreeNode(16);
        BinaryTreeNode d = new BinaryTreeNode(4);
        BinaryTreeNode e = new BinaryTreeNode(19);
        BinaryTreeNode f = new BinaryTreeNode(2);
        BinaryTreeNode g = new BinaryTreeNode(15);
        BinaryTreeNode h = new BinaryTreeNode(91);
        BinaryTreeNode i = new BinaryTreeNode(8);

        BinaryTreeNode root = a;
        a.left = b;
        a.right = c;
        b.right = e;
        c.right = d;
        e.left = f;
        f.left = g;
        f.right = h;
        h.right = i;

        /*
                5
              /   \
            16     16
              \     \
              19     4
              /
             2
            / \
           15  91
                \
                 8
        */

        List<BinaryTreePath> pathsWithSum = getBinaryTreePathsWithSum(root, 112); // 19 => 2 => 91

        System.out.println(Arrays.toString(pathsWithSum.toArray()));
    }

    public static List<BinaryTreePath> getBinaryTreePathsWithSum(BinaryTreeNode root, int sum) {
        if (root == null) {
            throw new IllegalArgumentException("Must pass non-null binary tree!");
        }

        List<BinaryTreePath> paths = new ArrayList<BinaryTreePath>();
        Map<Integer, List<BinaryTreeNode>> cumulativeSumMap = new HashMap<Integer, List<BinaryTreeNode>>();

        populateBinaryTreePathsWithSum(root, 0, cumulativeSumMap, sum, paths);

        return paths;
    }

    private static void populateBinaryTreePathsWithSum(BinaryTreeNode node, int cumulativeSumBeforeNode, Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int targetSum, List<BinaryTreePath> paths) {
        if (node == null) {
            return;
        }

        addToMap(cumulativeSumMap, cumulativeSumBeforeNode, node);

        int cumulativeSumIncludingNode = cumulativeSumBeforeNode + node.value;
        int sumToFind = cumulativeSumIncludingNode - targetSum;

        if (cumulativeSumMap.containsKey(sumToFind)) {
            List<BinaryTreeNode> candidatePathStartNodes = cumulativeSumMap.get(sumToFind);

            for (BinaryTreeNode pathStartNode : candidatePathStartNodes) {
                paths.add(new BinaryTreePath(pathStartNode, node));
            }
        }

        populateBinaryTreePathsWithSum(node.left, cumulativeSumIncludingNode, cumulativeSumMap, targetSum, paths);
        populateBinaryTreePathsWithSum(node.right, cumulativeSumIncludingNode, cumulativeSumMap, targetSum, paths);

        removeFromMap(cumulativeSumMap, cumulativeSumBeforeNode);
    }

    private static void addToMap(Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int cumulativeSumBeforeNode, BinaryTreeNode node) {
        if (cumulativeSumMap.containsKey(cumulativeSumBeforeNode)) {
            List<BinaryTreeNode> nodes = cumulativeSumMap.get(cumulativeSumBeforeNode);
            nodes.add(node);
        } else {
            List<BinaryTreeNode> nodes = new ArrayList<BinaryTreeNode>();
            nodes.add(node);
            cumulativeSumMap.put(cumulativeSumBeforeNode, nodes);
        }
    }

    private static void removeFromMap(Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int cumulativeSumBeforeNode) {
        List<BinaryTreeNode> nodes = cumulativeSumMap.get(cumulativeSumBeforeNode);
        nodes.remove(nodes.size() - 1);
    }

    private static class BinaryTreeNode {
        public int value;
        public BinaryTreeNode left;
        public BinaryTreeNode right;

        public BinaryTreeNode(int value) {
            this.value = value;
        }

        public String toString() {
            return this.value + "";
        }

        public int hashCode() {
            return Integer.valueOf(this.value).hashCode();
        }

        public boolean equals(Object other) {
            return this == other;
        }
    }

    private static class BinaryTreePath {
        public BinaryTreeNode start;
        public BinaryTreeNode end;

        public BinaryTreePath(BinaryTreeNode start, BinaryTreeNode end) {
            this.start = start;
            this.end = end;
        }

        public String toString() {
            return this.start + " to " + this.end;
        }
    }
}
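The code above reports each match as a [start node, end node] pair, while the question asks to print the paths themselves. A minimal variation, under the assumption that a list of values is an acceptable way to print a path: keep the values of the current node's ancestors in a list indexed by depth, and store depths (rather than nodes) in the hash table, so each match can be copied out as a slice. The class and method names here are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PathsWithSumFullPaths {
    // Hypothetical node type; mirrors BinaryTreeNode above but adds a convenience constructor.
    static class Node {
        int value;
        Node left, right;

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Every downward path whose values add up to target, each listed top to bottom.
    static List<List<Integer>> pathsWithSum(Node root, int target) {
        List<List<Integer>> results = new ArrayList<>();
        walk(root, 0, 0, target, new ArrayList<>(), new HashMap<>(), results);
        return results;
    }

    private static void walk(Node node, int depth, int sumBefore, int target, List<Integer> pathValues,
                             Map<Integer, List<Integer>> depthsBySumBefore, List<List<Integer>> results) {
        if (node == null) {
            return;
        }
        // Same bookkeeping as addToMap, but storing the node's depth instead of the node itself.
        depthsBySumBefore.computeIfAbsent(sumBefore, k -> new ArrayList<>()).add(depth);
        pathValues.add(node.value); // pathValues.get(d) is the value of the ancestor at depth d

        int sumIncluding = sumBefore + node.value;
        List<Integer> startDepths = depthsBySumBefore.get(sumIncluding - target);
        if (startDepths != null) {
            for (int start : startDepths) {
                // Copy the slice, because pathValues changes as the traversal backtracks.
                results.add(new ArrayList<>(pathValues.subList(start, depth + 1)));
            }
        }

        walk(node.left, depth + 1, sumIncluding, target, pathValues, depthsBySumBefore, results);
        walk(node.right, depth + 1, sumIncluding, target, pathValues, depthsBySumBefore, results);

        // Backtrack, as removeFromMap does: undo both additions before returning to the parent.
        pathValues.remove(pathValues.size() - 1);
        List<Integer> depths = depthsBySumBefore.get(sumBefore);
        depths.remove(depths.size() - 1);
    }

    // The same tree as in the answer's main method.
    static Node sampleTree() {
        Node f = new Node(2, new Node(15, null, null), new Node(91, null, new Node(8, null, null)));
        Node b = new Node(16, null, new Node(19, f, null));
        Node c = new Node(16, null, new Node(4, null, null));
        return new Node(5, b, c);
    }

    public static void main(String[] args) {
        System.out.println(pathsWithSum(sampleTree(), 112)); // prints [[19, 2, 91]]
    }
}
```

Copying each slice costs time proportional to the path length, so printing full paths is no longer O(n + numResults) but O(n + total length of all reported paths).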



Answer 3:


Based on Christian's answer above:

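// Start a fresh path at every node, exactly once per node.
public void printSums(Node n, int sum) {
    if (n == null) {
        return;
    }
    printSumsFrom(n, sum, 0, "");  // paths that begin at n
    printSums(n.left, sum);        // paths that begin in the left subtree
    printSums(n.right, sum);       // paths that begin in the right subtree
}

// Extend the current path downward from n, printing every prefix that hits the sum.
// (Restarting with currentSum = 0 from inside this method would restart the same
// subtree several times and print duplicate paths.)
private void printSumsFrom(Node n, int sum, int currentSum, String buffer) {
    if (n == null) {
        return;
    }
    int newSum = currentSum + n.val;
    String newBuffer = buffer + " " + n.val;
    if (newSum == sum) {
        System.out.println(newBuffer);
    }
    printSumsFrom(n.left, sum, newSum, newBuffer);
    printSumsFrom(n.right, sum, newSum, newBuffer);
}

printSums(root, targetSum);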



Answer 4:


Here is an approach with O(n log n) complexity.

  1. Traverse the tree in order.
  2. At the same time, maintain all the nodes, keyed by their cumulative sum, in a HashMap<CumulativeSum, reference to the corresponding node>.
  3. At a given node, calculate the cumulative sum from the root to that node; call it SUM.
  4. Look up the value SUM - K in the HashMap, where K is the target sum.
  5. If the entry exists, take the corresponding node reference from the HashMap.
  6. We now have a valid path from that node reference to the current node.



Answer 5:


A clean solution in Java, using internal recursive calls that keep track of the traversed path. Note that it only reports paths that run from the root down to a leaf.

private static void pathSunInternal(TreeNode root, int sum, List<List<Integer>> result, List<Integer> path){
    if(root == null)
        return;     
    path.add(root.val);
    if(sum == root.val && root.left == null && root.right == null){
        result.add(path);
    }

    else if(sum != root.val && root.left == null && root.right == null)
        return;
    else{
        List<Integer> leftPath = new ArrayList<>(path);
        List<Integer> rightPath = new ArrayList<>(path);
        pathSunInternal(root.left, sum - root.val, result, leftPath);
        pathSunInternal(root.right, sum - root.val, result, rightPath);
    }
}

public static List<List<Integer>> pathSum(TreeNode root, int sum) {
    List<List<Integer>> result = new ArrayList<>(); 
    List<Integer> path = new ArrayList<>();
    pathSunInternal(root, sum, result, path);       
    return result;
}



Answer 6:


Update: I see now that my answer does not directly answer your question. I will leave it here in case it proves useful, but it needs no upvotes. If it is not useful, I'll remove it. I do agree with @nhahtdh, however, when he advises, "Reuse your algorithm with all other nodes as root."

One suspects that the interviewer is fishing for recursion here. Don't disappoint him!

Given a node, your routine should call itself on each of its child nodes, if it has any, then add the node's own datum to the values those calls return, and return the sum.

For extra credit, warn the interviewer that your routine can fail, entering a bottomless recursion, if used on a general graph rather than a binary tree.
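A minimal Java sketch of the recursion described above, assuming a hypothetical Node class. It computes the total of a subtree (what the paragraph describes), not the path search itself.

```java
public class SubtreeSum {
    // Hypothetical node type for this sketch.
    static class Node {
        int value;
        Node left, right;

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Call itself on each child, add the node's own value, and return the total.
    // On a general graph with a cycle this would recurse until the stack overflows.
    static int subtreeSum(Node node) {
        if (node == null) {
            return 0; // no child here
        }
        return subtreeSum(node.left) + subtreeSum(node.right) + node.value;
    }

    public static void main(String[] args) {
        Node root = new Node(5, new Node(3, null, null), new Node(-2, new Node(4, null, null), null));
        System.out.println(subtreeSum(root)); // prints 10
    }
}
```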




Answer 7:


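One can reduce this tree to a weighted graph G, where each edge's weight is the sum of the values of its two endpoint nodes.

Then, run the Floyd-Warshall algorithm on G. Every interior node of a path lies on two of its edges and each endpoint on one, so the sum of node values on the path between u and v is (dist(u, v) + value(u) + value(v)) / 2. By inspecting the elements of the resulting matrix, we can get all pairs of nodes between which this sum equals the desired sum.

Also, note that the shortest path the algorithm gives is also the only path between 2 nodes in this tree, as long as no edge weight is negative (an undirected edge with negative weight is a negative cycle, which Floyd-Warshall cannot handle).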

This is just another approach, not as efficient as a recursive approach.
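A sketch of the Floyd-Warshall reduction in Java, assuming nodes are numbered 0..n-1 with their values in an array and the tree given as an edge list (a hypothetical input format). It also assumes non-negative node values, which keeps every edge weight non-negative, so the shortest walk between two nodes is exactly their unique tree path.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class FloydWarshallPathSums {
    // Returns every pair {i, j} with i <= j whose connecting tree path has node values summing to target.
    static List<int[]> pairsWithSum(int[] values, int[][] edges, long target) {
        int n = values.length;
        long inf = Long.MAX_VALUE / 4; // means "no edge"; small enough that inf + inf cannot overflow
        long[][] dist = new long[n][n];
        for (int i = 0; i < n; i++) {
            Arrays.fill(dist[i], inf);
            dist[i][i] = 0;
        }
        // Edge weight = sum of the values at the edge's two endpoints.
        for (int[] e : edges) {
            long w = (long) values[e[0]] + values[e[1]];
            dist[e[0]][e[1]] = w;
            dist[e[1]][e[0]] = w;
        }
        // Standard Floyd-Warshall relaxation, O(n^3) time.
        for (int k = 0; k < n; k++) {
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    if (dist[i][k] + dist[k][j] < dist[i][j]) {
                        dist[i][j] = dist[i][k] + dist[k][j];
                    }
                }
            }
        }
        List<int[]> result = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                if (dist[i][j] >= inf) {
                    continue; // not connected (cannot happen in a single tree)
                }
                // Interior nodes sit on two path edges, endpoints on one: add the endpoints once more and halve.
                long pathSum = (dist[i][j] + values[i] + values[j]) / 2;
                if (pathSum == target) {
                    result.add(new int[] {i, j});
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical tree: node 0 (5) has children 1 (16) and 2 (16); node 2 has child 3 (4).
        int[] values = {5, 16, 16, 4};
        int[][] edges = {{0, 1}, {0, 2}, {2, 3}};
        for (int[] pair : pairsWithSum(values, edges, 21)) {
            System.out.println(pair[0] + " - " + pair[1]); // prints "0 - 1" and then "0 - 2"
        }
    }
}
```

The O(n^3) running time is why this is slower than the recursive approaches; it does, however, also find paths that go up through a common ancestor and back down.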




Answer 8:


We can solve it with tree-structured dynamic programming; both the time and space complexity are O(n^2), where n is the number of tree nodes.

The idea is as follows:

For a tree node u, we keep a set recording the sums of all paths that start at u and go down to one of its descendants (including the path consisting of u alone). Recursively, any node's set can then be built from its two children's sets, specifically by adding u's value to every sum in them and merging the results.

The pseudocode is:

bool findPath(Node u, Set uset, int finalSum) {
    // uset is an output parameter: on return it holds the sum of every downward path starting at u
    if (u == null) return false;
    Set lset, rset;
    if (findPath(u.left, lset, finalSum) || findPath(u.right, rset, finalSum)) return true;
    if (u.val == finalSum) return true; // the single-node path {u}
    uset.insert(u.val);                 // the path consisting of u alone
    for (int s1 : lset) {
        if (finalSum - u.val - s1 == 0 || rset.contains(finalSum - u.val - s1)) return true;
        // first condition means a path from u to some descendant in u's left child
        // second condition means a path from some node in u's left child, through u, to some node in u's right child

        uset.insert(s1 + u.val); // update u's set
    }
    for (int s2 : rset) {
        if (finalSum - u.val - s2 == 0) return true;
        // don't forget the path from u to some descendant in u's right child
        uset.insert(s2 + u.val); // update u's set
    }
    return false;
}

I notice the original question is to find all paths, but the algorithm above only determines whether such a path exists. I think the idea is similar, but this version makes the problem easier to explain :)
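As a step toward "all paths", here is a sketch that counts them with the same tree DP, replacing each set by a multiset (a map from sum to the number of downward paths with that sum). Each path is counted exactly once, at its highest node u: the path {u} alone, u down into the left subtree, u down into the right subtree, or left subtree through u into the right subtree. The class and method names are hypothetical; listing the paths themselves would require storing them instead of counts.

```java
import java.util.HashMap;
import java.util.Map;

public class CountPathsTreeDp {
    // Hypothetical node type for this sketch.
    static class Node {
        int val;
        Node left, right;

        Node(int val, Node left, Node right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Returns u's multiset of downward path sums (sum -> number of downward paths from u with that sum)
    // and adds to total[0] the number of target-sum paths whose highest node is u.
    static Map<Integer, Long> downSums(Node u, int target, long[] total) {
        Map<Integer, Long> uset = new HashMap<>();
        if (u == null) {
            return uset;
        }
        Map<Integer, Long> lset = downSums(u.left, target, total);
        Map<Integer, Long> rset = downSums(u.right, target, total);
        int need = target - u.val;
        if (need == 0) {
            total[0]++;                               // the single-node path {u}
        }
        total[0] += lset.getOrDefault(need, 0L);      // from u down into the left subtree
        total[0] += rset.getOrDefault(need, 0L);      // from u down into the right subtree
        for (Map.Entry<Integer, Long> e : lset.entrySet()) {
            // from the left subtree, up through u, down into the right subtree
            total[0] += e.getValue() * rset.getOrDefault(need - e.getKey(), 0L);
        }
        uset.put(u.val, 1L);                          // the path consisting of u alone
        for (Map.Entry<Integer, Long> e : lset.entrySet()) {
            uset.merge(e.getKey() + u.val, e.getValue(), Long::sum);
        }
        for (Map.Entry<Integer, Long> e : rset.entrySet()) {
            uset.merge(e.getKey() + u.val, e.getValue(), Long::sum);
        }
        return uset;
    }

    static long countPaths(Node root, int target) {
        long[] total = {0};
        downSums(root, target, total);
        return total[0];
    }

    public static void main(String[] args) {
        // Hypothetical tree: 1 with children 2 and 3.
        Node root = new Node(1, new Node(2, null, null), new Node(3, null, null));
        System.out.println(countPaths(root, 3)); // prints 2: {3} and {1, 2}
        System.out.println(countPaths(root, 6)); // prints 1: {2, 1, 3}
    }
}
```

Treating hash-map operations as O(1), the work at each node is proportional to the size of its subtree, so this stays within the O(n^2) bound mentioned above.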




Answer 9:


public void printPath(N n) {
    printPath(n,n.parent);
}

private void printPath(N n, N endN) {
    if (n == null)
        return;
    if (n.left == null && n.right == null) {
        do {
            System.out.print(n.value);
            System.out.print(" ");
        } while ((n = n.parent)!=endN);
        System.out.println("");
        return;
    }
    printPath(n.left, endN);
    printPath(n.right, endN);
}

This prints every path from a leaf in n's subtree up to n itself. Call it like this: printPath(n);




Answer 10:


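#include <stdio.h>

struct node
{
  int data;
  struct node *left, *right;
};

/* Print arr[start..end] on one line. */
void print(int arr[], int end, int start)
{
  int i;
  for (i = start; i <= end; i++)
    printf("%d ", arr[i]);
  printf("\n");
}

/* arr[0..level] holds the values on the path from the tree root down to root. */
void printpath(int sum, int arr[], int level, struct node *root)
{
  int tmp = sum, i;
  if (root == NULL)
    return;
  arr[level] = root->data;
  /* Walk back toward the tree root; whenever the remainder hits 0, arr[i..level] adds up to sum. */
  for (i = level; i >= 0; i--)
  {
    tmp -= arr[i];
    if (tmp == 0)
      print(arr, level, i);
  }
  printpath(sum, arr, level + 1, root->left);
  printpath(sum, arr, level + 1, root->right);
}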

Time complexity: O(n log n) for a balanced tree (O(n·h) in general, where h is the height of the tree). Space complexity: O(n).




Answer 11:


#include <stdio.h>
#include <stdlib.h>

struct Node
{
    int data;
    struct Node *left, *right;
};

struct Node *newNode(int item)
{
    struct Node *temp = malloc(sizeof(struct Node));
    temp->data = item;
    temp->left = NULL;
    temp->right = NULL;
    return temp;
}

/* Print p[t..level] on one line, so separate paths stay distinguishable. */
void print(int p[], int level, int t)
{
    int i;
    for (i = t; i <= level; i++)
        printf("%d ", p[i]);
    printf("\n");
}

/* path[0..level] holds the values from the root down to the current node (depth must stay below 100). */
void check_paths_with_given_sum(struct Node *root, int da, int path[100], int level)
{
    int i, temp = 0;
    if (root == NULL)
        return;
    path[level] = root->data;
    /* Sum from the current node back up toward the root, printing each sub-path that hits da. */
    for (i = level; i >= 0; i--) {
        temp += path[i];
        if (temp == da)
            print(path, level, i);
    }
    check_paths_with_given_sum(root->left, da, path, level + 1);
    check_paths_with_given_sum(root->right, da, path, level + 1);
}

int main(void)
{
    int par[100];
    struct Node *root = newNode(10);
    root->left = newNode(2);
    root->right = newNode(4);
    root->left->left = newNode(1);
    root->right->right = newNode(5);
    check_paths_with_given_sum(root, 9, par, 0); /* prints 4 5 */
    return 0;
}

This works.




Answer 12:


https://codereview.stackexchange.com/questions/74957/find-all-the-paths-of-tree-that-add-to-a-input-value

I have posted an answer there for code review. My code, as well as the reviewers' feedback, should be a helpful source.




Answer 13:


Below is a solution using recursion. We perform a pre-order traversal of the binary tree. As we move down a level, we record the current node's weight, then sum the weights from the current level back up through the previous levels of the tree; whenever the total hits our sum, we print that path. This solution handles cases where there is more than one solution along any given path.

Assume you have a binary tree rooted at root.

#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

class Node
{
private:
    Node* left;
    Node* right;
    int value;

public:
    Node(const int value)
    {
        left=NULL;
        right=NULL;
        this->value=value;
    }

    void setLeft(Node* left)
    {
        this->left=left;
    }

    void setRight(Node* right)
    {
        this->right = right;
    }

    Node* getLeft() const
    {
        return left;
    }

    Node* getRight() const
    {
        return right;
    }

    const int& getValue() const
    {
        return value;
    }
};

//get maximum height of the tree so we know how much space to allocate for our
//path vector

int getMaxHeight(Node* root)
{
    if (root == NULL)
        return 0;

    int leftHeight = getMaxHeight(root->getLeft());
    int rightHeight = getMaxHeight(root->getRight());

    return max(leftHeight, rightHeight) + 1;
}

//found our target sum, output the path
void printPaths(vector<int>& paths, int start, int end)
{
    for(int i = start; i<=end; i++)
        cerr<<paths[i]<< " ";

    cerr<<endl;
}

void generatePaths(Node* root, vector<int>& paths, int depth, const int sum)
{
    //base case, empty tree, no path
    if( root == NULL)
        return;

    paths[depth] = root->getValue();
    int total =0;

    //sum up the weights of the nodes in the path traversed
    //so far, if we hit our target, output the path
    for(int i = depth; i>=0; i--)
    {
        total += paths[i];
        if(total == sum)
            printPaths(paths, i, depth);
    }

    //go down 1 level where we will then sum up from that level
    //back up the tree to see if any sub path hits our target sum
    generatePaths(root->getLeft(), paths, depth+1, sum);
    generatePaths(root->getRight(), paths, depth+1, sum);
}

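int main(void)
{
    // Hypothetical sample tree: 10 with children 2 and 4; 2 has left child 1, 4 has right child 5.
    Node root(10), a(2), b(4), c(1), d(5);
    root.setLeft(&a);
    root.setRight(&b);
    a.setLeft(&c);
    b.setRight(&d);

    const int targetSum = 9; // hypothetical target; prints the path 4 5
    vector<int> paths(getMaxHeight(&root));
    generatePaths(&root, paths, 0, targetSum);
    return 0;
}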

Space complexity depends on the height h of the tree: the recursion stack and the paths vector both need O(h) space, which is O(log n) for a balanced tree. Time complexity: each of the n nodes sums back up its path toward the root, which costs O(h) per node, so the total is O(n·h). For a balanced tree, whose height is bounded by O(log n), that gives O(n log n); a degenerate, list-like tree gives O(n²).




Answer 14:


// assumption: nodes have non-zero integer values
void printAllPaths(Node root, int sum, ArrayList<Integer> path) {
    if (sum == 0) {
        print(path); // simply print the ArrayList
    }

    if (root == null) {
        // traversed one end of the tree... just return
        return;
    }
    int data = root.data;
    // this node can be at the start, end or middle of a path only if it is less than the sum
    if (data <= sum) {
        path.add(data);
        // go left and right
        printAllPaths(root.left, sum - data, path);
        printAllPaths(root.right, sum - data, path);
    }
    // note: this is not an else branch, so that a path can start from anywhere
    printAllPaths(root.left, sum, path);
    printAllPaths(root.right, sum, path);
}



Answer 15:


I have improved some of the coding logic of the answer by Arvind Upadhyay. Once the if block is done, you cannot reuse the same list, so you need to create a new one. You also need to keep a count of how many levels the logic has gone down from the current node while searching for a path. If no path is found, we must come back out of that many (count) recursive calls before moving on to the node's children.

int count =0;
public void printAllPathWithSum(Node node, int sum, ArrayList<Integer> list)
{   
    if(node == null)
        return;
    if(node.data<=sum)
    {
        list.add(node.data);
        if(node.data == sum)
            print(list);
        else
        {
            count ++;
            printAllPathWithSum(node.left, sum-node.data, list);
            printAllPathWithSum(node.right, sum-node.data, list);
            count --;
        }
    }
    if(count != 0)
        return ;


    printAllPathWithSum(node.left, this.sum, new ArrayList());
    if(count != 0)
        return;
    printAllPathWithSum(node.right, this.sum, new ArrayList());

}
public void print(List list)
{
    System.out.println("Next path");
    for(int i=0; i<list.size(); i++)
        System.out.print(Integer.toString((Integer)list.get(i)) + " ");
    System.out.println();
}

Check the full code at: https://github.com/ganeshzilpe/java/blob/master/Tree/BinarySearchTree.java




Answer 16:


Search:

Recursively traverse the tree, comparing with the input key, as in a binary search tree.

If the key is found, move the target node (where the key was found) to the root position using splay steps.

Pseudocode:


Algorithm: search (key)
Input: a search-key
1.   found = false;
2.   node = recursiveSearch (root, key)
3.   if found
4.     Move node to root-position using splay steps;
5.     return value
6.   else
7.     return null
8.   endif
Output: value corresponding to key, if found.



Algorithm: recursiveSearch (node, key)
Input: tree node, search-key
1.   if key = node.key
2.     found = true
3.     return node
4.   endif
     // Otherwise, traverse further 
5.   if key < node.key
6.     if node.left is null
7.       return node
8.     else
9.       return recursiveSearch (node.left, key)
10.    endif
11.  else
12.    if node.right is null
13.      return node
14.    else
15.      return recursiveSearch (node.right, key)
16.    endif
17.  endif
Output: pointer to node where found; if not found, pointer to node for insertion.
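A Java sketch of recursiveSearch only, assuming a hypothetical Node class with an int key. Instead of the global found flag, the caller compares the returned node's key with the search key; the splay steps that move the node to the root are not implemented here.

```java
public class BstSearchSketch {
    // Hypothetical node type for this sketch.
    static class Node {
        int key;
        Node left, right;

        Node(int key, Node left, Node right) {
            this.key = key;
            this.left = left;
            this.right = right;
        }
    }

    // Mirrors recursiveSearch: returns the node holding key or, if key is absent,
    // the last node visited (the node under which key would be inserted).
    static Node recursiveSearch(Node node, int key) {
        if (key == node.key) {
            return node;
        }
        if (key < node.key) {
            return node.left == null ? node : recursiveSearch(node.left, key);
        }
        return node.right == null ? node : recursiveSearch(node.right, key);
    }

    public static void main(String[] args) {
        Node root = new Node(8, new Node(3, null, null), new Node(10, null, null));
        Node hit = recursiveSearch(root, 3);
        Node miss = recursiveSearch(root, 9);
        System.out.println(hit.key == 3); // prints true (found)
        System.out.println(miss.key);     // prints 10 (9 would be inserted as its left child)
    }
}
```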



Answer 17:


Since we need the paths having sum == k, I assume the worst-case complexity can be O(total_paths_in_tree).

So why not generate every path and check its sum? After all, the tree can have negative numbers and is not even a binary search tree.

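#include <cstddef>
#include <numeric>
#include <vector>
using namespace std;

struct node {
    int val;
    node *left, *right;

    node(int vl) : val(vl), left(NULL), right(NULL) {}
};

// Every path found so far, listed from one endpoint to the other.
vector<vector<int> > all_paths;

// Returns every downward path ending at root, stored bottom-up (deepest node first, root last),
// and appends to all_paths every path whose highest node is root.
vector<vector<int> > gen_paths(node* root)
{
    if (root == NULL)
        return vector<vector<int> >();

    vector<vector<int> > left_paths = gen_paths(root->left);
    vector<vector<int> > right_paths = gen_paths(root->right);

    left_paths.push_back(vector<int>()); // empty path
    right_paths.push_back(vector<int>());

    // left part (bottom-up) + root + right part reversed, so the whole path reads in order
    for (size_t i = 0; i < left_paths.size(); i++) {
        for (size_t j = 0; j < right_paths.size(); j++) {
            vector<int> vec(left_paths[i].begin(), left_paths[i].end());
            vec.push_back(root->val);
            vec.insert(vec.end(), right_paths[j].rbegin(), right_paths[j].rend());
            all_paths.push_back(vec);
        }
    }

    // Paths that can still be extended upward: every left or right path plus root.
    // The empty path is taken only once (from left_paths) so that [root] is not duplicated.
    vector<vector<int> > paths_to_extend;
    for (size_t i = 0; i < left_paths.size(); i++) {
        paths_to_extend.push_back(left_paths[i]);
        paths_to_extend.back().push_back(root->val);
    }
    for (size_t i = 0; i + 1 < right_paths.size(); i++) {
        paths_to_extend.push_back(right_paths[i]);
        paths_to_extend.back().push_back(root->val);
    }
    return paths_to_extend;
}

// Generate every path, then keep the ones whose values add up to k.
vector<vector<int> > paths_with_sum(node* root, int k)
{
    all_paths.clear();
    gen_paths(root);
    vector<vector<int> > result;
    for (size_t i = 0; i < all_paths.size(); i++)
        if (accumulate(all_paths[i].begin(), all_paths[i].end(), 0) == k)
            result.push_back(all_paths[i]);
    return result;
}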

To generate the paths, at each node I generate all left paths and all right paths, and add left_paths + node->val + right_paths to all_paths. I then return the paths that can still be extended upward, i.e. all paths from both sides plus the node.



Source: https://stackoverflow.com/questions/11328358/algorithm-to-print-all-paths-with-a-given-sum-in-a-binary-tree
