Question
Is there a better functional way to write flatMap?
def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] =
  list.map(x => f(x)).flatten
Conceptually, I understand flatMap in terms of flatten.
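To make the map-then-flatten view concrete, here is a small trace. The function `dup` and the input `nums` are made up purely for illustration; `dup` turns each element into a two-element list:

```scala
// Hypothetical example function: each element becomes a two-element list.
val dup: Int => List[Int] = x => List(x, x)

val nums = List(1, 2, 3)

// Step 1: map produces a nested List[List[Int]].
val nested: List[List[Int]] = nums.map(dup)
// nested == List(List(1, 1), List(2, 2), List(3, 3))

// Step 2: flatten concatenates the inner lists, keeping their order.
val flat: List[Int] = nested.flatten
// flat == List(1, 1, 2, 2, 3, 3)
```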
Answer 1:
An alternate approach:
def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] =
  list.foldLeft(List[B]())(_ ++ f(_))
I don't know about “better”. (And if we start talking about efficient implementation, that's another can of worms...)
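The placeholder lambda `_ ++ f(_)` expands to a two-argument function `(acc, x) => acc ++ f(x)`, where `acc` is the list built so far. A sketch with the parameters spelled out (the name `flatMapFoldLeft` and the sample words are chosen here only for illustration):

```scala
def flatMapFoldLeft[A, B](list: List[A])(f: A => List[B]): List[B] =
  // acc holds the result so far; each step appends f(x) to its end.
  list.foldLeft(List.empty[B])((acc, x) => acc ++ f(x))

val words = List("ab", "cde")
println(flatMapFoldLeft(words)(_.toList)) // List(a, b, c, d, e)
```

Appending with `++` re-traverses the growing accumulator on every step, which is the cost the last answer below points out.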
Answer 2:
Just to flesh out the answers, you could also define this as a recursive function using pattern matching:
def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] = list match {
  case x :: xs => f(x) ++ flatMap(xs)(f)
  case Nil     => Nil
}
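The recursive call above is not in tail position, so a very long list can overflow the stack. To avoid that, make it explicitly tail-recursive: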
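import scala.annotation.tailrec

def flatMapTailRec[A, B](list: List[A])(f: A => List[B]): List[B] = {
  // The accumulator holds the output in reverse: prepending each chunk reversed
  // keeps every step cheap, and one final reverse restores the original order.
  // (Prepending f(x) ++ result directly would emit the chunks backwards.)
  @tailrec
  def _flatMap(result: List[B])(input: List[A]): List[B] = input match {
    case x :: xs => _flatMap(f(x).reverse ::: result)(xs)
    case Nil     => result.reverse
  }
  _flatMap(Nil)(list)
}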
I did a bit of quick, non-rigorous benchmarking, using this sample input:
val input = (0 to 1000).map(_ => (0 to 1000).toList).toList
In order from fastest to slowest:
flatMap(input)(x => x): 0.02069937453 seconds
flatMapTailRec(input)(x => x): 0.02335651054 seconds
input.flatMap(x => x): 0.0297564358 seconds
flatMapFoldLeft(input)(x => x): 12.940458234 seconds
I'm a little surprised that foldLeft comes out so much slower than the others. It would be interesting to see how flatMap is actually defined in the library source. I tried looking myself, but there's too much to go through at the moment >_>.
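For comparison with library code: collection implementations commonly avoid both deep recursion and repeated `++` by filling a mutable builder that never escapes the method. A sketch of that idea using `scala.collection.mutable.ListBuffer` (the name `flatMapBuffer` is hypothetical, and this is not the standard library's actual source):

```scala
import scala.collection.mutable.ListBuffer

def flatMapBuffer[A, B](list: List[A])(f: A => List[B]): List[B] = {
  // The mutation is local: the buffer never escapes, so callers still see
  // a pure function returning an immutable List.
  val buf = ListBuffer.empty[B]
  // Each element of each f(x) is appended exactly once, in order.
  for (x <- list) buf ++= f(x)
  buf.toList
}

println(flatMapBuffer(List(1, 2, 3))(x => List(x, x))) // List(1, 1, 2, 2, 3, 3)
```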
Edit: As Daniel Sobral pointed out in the comments on another answer, these implementations are restricted to List[A]. You could write a more generic implementation that works for any mappable type, but the code would quickly become much more complicated.
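A minimal sketch of what "any mappable type" could look like, assuming a type class approach. `Flattenable`, `flatMapGeneric`, and the two instances are hypothetical names invented here; libraries such as Cats ship their own, richer type classes (e.g. `FlatMap`, `Monad`) for this purpose:

```scala
// Hypothetical minimal type class: anything we can map over and flatten.
trait Flattenable[F[_]] {
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def flatten[A](ffa: F[F[A]]): F[A]
}

// flatMap written once, generically, as map followed by flatten.
def flatMapGeneric[F[_], A, B](fa: F[A])(f: A => F[B])(implicit F: Flattenable[F]): F[B] =
  F.flatten(F.map(fa)(f))

// Instances for List and Option that delegate to the standard library.
implicit val listFlattenable: Flattenable[List] = new Flattenable[List] {
  def map[A, B](fa: List[A])(f: A => B): List[B] = fa.map(f)
  def flatten[A](ffa: List[List[A]]): List[A] = ffa.flatten
}

implicit val optionFlattenable: Flattenable[Option] = new Flattenable[Option] {
  def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa.map(f)
  def flatten[A](ffa: Option[Option[A]]): Option[A] = ffa.flatten
}

println(flatMapGeneric(List(1, 2))(x => List(x, x * 10)))             // List(1, 10, 2, 20)
println(flatMapGeneric(Option(3))(x => if (x > 0) Some(x) else None)) // Some(3)
```

Every new container needs its own instance, which is where the extra complexity comes from.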
Answer 3:
Your solution is already pretty functional, and it emphasizes where the name flatMap comes from. Note that x => f(x) is just f, so it boils down to:
list.map(f).flatten
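A quick sanity check that the eta-reduced version agrees with the built-in `flatMap`, including on the empty list. The name `flatMapViaFlatten` and the test function `repeatSelf` are made up for this sketch:

```scala
// The question's version with the eta-reduction applied: x => f(x) becomes f.
def flatMapViaFlatten[A, B](list: List[A])(f: A => List[B]): List[B] =
  list.map(f).flatten

// Hypothetical test function: n is repeated n times (0 yields an empty list).
val repeatSelf: Int => List[Int] = n => List.fill(n)(n)

for (xs <- List(Nil, List(0), List(1, 2, 3)))
  assert(flatMapViaFlatten(xs)(repeatSelf) == xs.flatMap(repeatSelf))
```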
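Using foldLeft is a terrible idea, because of the quadratic behavior induced by repeated concatenation: each ++ traverses its left operand, which here is the ever-growing accumulator. E.g. folding over a, b, c, d computes (((List() ++ a) ++ b) ++ c) ++ d, which traverses a 3 times, b twice, and so on.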
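The cost can be counted directly. Suppose the input holds n chunks of k elements each (symbols introduced here for illustration). Appending chunk j traverses the (j-1)k elements accumulated before it, so the total traversal is:

```latex
\text{traversed} = \sum_{j=1}^{n} (j-1)\,k = k\,\frac{n(n-1)}{2},
\qquad
n = k = 1001 \;\Rightarrow\; 1001 \cdot \frac{1001 \cdot 1000}{2} = 501\,000\,500
```

For the benchmark input in the previous answer, that is about 5 × 10^8 element visits to produce an output of only 1,002,001 elements, which fits the large gap in the foldLeft timing.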
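foldRight would be fine: it computes a ++ (b ++ (c ++ (d ++ List()))), where the accumulator is always the right operand, so each element is traversed only once.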
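A sketch of the foldRight version (the name `flatMapFoldRight` is chosen here only to avoid clashing with the other definitions):

```scala
def flatMapFoldRight[A, B](list: List[A])(f: A => List[B]): List[B] =
  // foldRight combines from the end, so each f(x) is prepended to an
  // already-built tail: a ++ (b ++ (c ++ Nil)). Only f(x) itself is copied.
  list.foldRight(List.empty[B])((x, acc) => f(x) ++ acc)

println(flatMapFoldRight(List(1, 2, 3))(x => List(x, -x))) // List(1, -1, 2, -2, 3, -3)
```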
Source: https://stackoverflow.com/questions/20436614/implementing-listflatmap