Functional style early exit from depth-first recursion

走远了吗, submitted on 2019-12-03 07:23:55

It can be done: you just have to write code that actually iterates through the children the way you want, instead of relying on a for loop.

More explicitly, you'll have to write code that iterates through a list of children and checks whether the "depth" (the number of nodes visited so far) has crossed your threshold. Here's some Haskell code (sorry, I'm not fluent in Scala, but it can probably be transliterated easily):

http://ideone.com/O5gvhM

In this code, I've basically replaced the for loop with an explicit recursive version. This lets me stop the recursion once the number of visited nodes has reached the limit (i.e., limit is no longer positive). When I recurse into the next child, I subtract the number of nodes visited by the dfs of the previous child and pass the result as the limit for the next child.
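A minimal Scala sketch of the approach just described, assuming the "print all nodes" variant: each call returns the labels it visited, so the number of nodes a subcall visited is simply the length of its list. The names Tree, dfsAll and traverseAll are hypothetical and not taken from the linked Haskell code.

```scala
// Hypothetical tree type for this sketch.
case class Tree[T](label: T, children: List[Tree[T]])

// Returns the labels of at most `limit` nodes, in pre-order.
// The number of nodes visited is the length of the returned list.
def dfsAll[T](node: Tree[T], limit: Int): List[T] =
  if (limit <= 0) Nil
  else node.label :: traverseAll(node.children, limit - 1) // this node costs 1

// Explicit recursion over the children instead of a for loop,
// so we can stop as soon as the budget runs out.
def traverseAll[T](children: List[Tree[T]], limit: Int): List[T] =
  children match {
    case _ if limit <= 0 => Nil // budget exhausted: skip remaining siblings
    case Nil             => Nil
    case first :: rest =>
      val visited = dfsAll(first, limit)
      // the next sibling only gets what the previous one left over
      // (visited.length is O(n); fine for a sketch)
      visited ++ traverseAll(rest, limit - visited.length)
  }
```

For a tree shaped like the one in the Scala example later in this answer (root 0 with children 1, 4 and 7), dfsAll(t, 4) visits only 0, 1, 2 and 3.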

Functional languages are fun, but they're a huge leap from imperative programming. They really make you pay attention to state, because when you go functional all of it is excruciatingly explicit in the function arguments.

EDIT: Explaining this a bit more.

I ended up converting the algorithm from "print just the leaf nodes" (the OP's original algorithm) to "print all nodes". That gave me the number of nodes each subcall visited, as the length of the list it returned. If you want to stick to the leaf nodes, you'll have to carry around how many nodes you have already visited:

http://ideone.com/cIQrna

EDIT again: To clean up this answer, I've put all the Haskell code on ideone and transliterated it to Scala, so this can stay here as the definitive answer to the question:

case class Node[T](label: T, children: Seq[Node[T]])

// numVisited counts every node entered (inner nodes included);
// labels collects only the labels of the leaves reached.
case class TraversalResult[T](numVisited: Int, labels: Seq[T])

// Visit node, entering at most limit nodes in total.
def dfs[T](node: Node[T], limit: Int): TraversalResult[T] =
  if (limit <= 0) TraversalResult(0, Nil) // budget used up: stop here
  else if (node.children.isEmpty) TraversalResult(1, List(node.label)) // leaf
  else {
    // this node costs 1; its children share what is left
    val result = traverse(node.children, limit - 1)
    TraversalResult(result.numVisited + 1, result.labels)
  }

// Explicit recursion over the siblings (instead of a for loop): each child
// gets whatever budget the children before it left over.
def traverse[T](children: Seq[Node[T]], limit: Int): TraversalResult[T] =
  if (limit <= 0) TraversalResult(0, Nil)
  else children match {
    case Seq() => TraversalResult(0, Nil)
    case first +: rest =>
      val travFirst = dfs(first, limit)
      val travRest = traverse(rest, limit - travFirst.numVisited)
      TraversalResult(
        travFirst.numVisited + travRest.numVisited,
        travFirst.labels ++ travRest.labels
      )
  }

val n = Node(0, List(
    Node(1, List(Node(2, Nil), Node(3, Nil))),
    Node(4, List(Node(5, List(Node(6, Nil))))),
    Node(7, Nil)
))
for (i <- 1 to 8)
    println(dfs(n, i))

Output:

TraversalResult(1,List())
TraversalResult(2,List())
TraversalResult(3,List(2))
TraversalResult(4,List(2, 3))
TraversalResult(5,List(2, 3))
TraversalResult(6,List(2, 3))
TraversalResult(7,List(2, 3, 6))
TraversalResult(8,List(2, 3, 6, 7))

P.S. this is my first attempt at Scala, so the above probably contains some horrid non-idiomatic code. I'm sorry.

You can convert breadth into depth (replace a loop over the elements with recursion) by passing along an index or by taking the tail:

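import scala.annotation.tailrec

// Recurse on the tail of a List, carrying the running total.
// Recursive methods need an explicit result type.
@tailrec
def suml(xs: List[Int], total: Int = 0): Int = xs match {
  case Nil => total
  case x :: rest => suml(rest, total + x)
}

// Recurse with an explicit index into an Array.
@tailrec
def suma(xs: Array[Int], from: Int = 0, total: Int = 0): Int =
  if (from >= xs.length) total
  else suma(xs, from + 1, total + xs(from))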

In the array version, the index from already gives you something to limit the breadth with if you want; in the list version, just add a width counter or something similar.
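For instance, the list version can be given a width limit by counting down one extra parameter. A minimal sketch; sumlUpTo and width are hypothetical names:

```scala
import scala.annotation.tailrec

// Sums at most `width` elements of xs: recurse on the tail while
// counting down the remaining width, and stop early when it hits 0.
@tailrec
def sumlUpTo(xs: List[Int], width: Int, total: Int = 0): Int = xs match {
  case x :: rest if width > 0 => sumlUpTo(rest, width - 1, total + x)
  case _                      => total // list exhausted or width used up
}
```

sumlUpTo(List(1, 2, 3, 4, 5), 3) only looks at the first three elements and returns 6.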

W.P. McNeill

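The following implements a lazy depth-first search over nodes in a tree. It is written against the pre-2.13 collections library (TraversableView is not part of the redesigned Scala 2.13 collections).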

import collection.TraversableView

case class Node[T](label: T, ns: Node[T]*)

// Start from a view containing just the root, then append the view of
// each child's subtree in order (/: is foldLeft).
def dfs[T](r: Node[T]): TraversableView[Node[T], Traversable[Node[T]]] =
  (Traversable[Node[T]](r).view /: r.ns) {
    (a, b) => (a ++ dfs(b)).asInstanceOf[TraversableView[Node[T], Traversable[Node[T]]]]
  }

This prints the labels of all the nodes in depth-first order.

val r = Node('a, Node('b, Node('d), Node('e, Node('f))), Node('c))
dfs(r).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd, 'e, 'f, 'c)

This does the same thing, quitting after 3 nodes have been visited.

dfs(r).take(3).map(_.label).force
// returns Traversable[Symbol] = List('a, 'b, 'd)

If you want only leaf nodes you can use filter, and so forth.

Note that the fold clause of the dfs function requires an explicit asInstanceOf cast. See "Type variance error in Scala when doing a foldLeft over Traversable views" for a discussion of the Scala typing issues that necessitate this.
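A similar lazy traversal can also be built on Iterator, which needs no cast and also works with the 2.13+ collections. This is a swapped-in technique, not the code above: Iterator's ++ takes its right-hand side by name and flatMap is lazy, so a child's subtree is only expanded once the traversal actually reaches it. dfsIter is a hypothetical name, and String labels stand in for the Symbol literals.

```scala
case class Node[T](label: T, ns: Node[T]*)

// Pre-order DFS as a lazy Iterator: yield the root, then (on demand)
// the nodes of each child's subtree, left to right.
def dfsIter[T](r: Node[T]): Iterator[Node[T]] =
  Iterator.single(r) ++ r.ns.iterator.flatMap(child => dfsIter(child))

val r = Node("a", Node("b", Node("d"), Node("e", Node("f"))), Node("c"))

val firstThree = dfsIter(r).take(3).map(_.label).toList      // stop after 3 nodes
val leaves = dfsIter(r).filter(_.ns.isEmpty).map(_.label).toList // leaf nodes only
```

Because the Iterator types line up on their own, the fold and the asInstanceOf cast disappear entirely.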
