Implementing List#flatMap

Submitted by 二次信任 on 2019-12-21 04:38:10

Question


Is there a better functional way to write flatMap?

def flatMap[A,B](list: List[A])(f: A => List[B]): List[B] =
    list.map(x => f(x)).flatten

Conceptually, I understand flatMap in terms of flatten.
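To see flatMap purely in terms of flatten, without relying on the library's own flatten, you can write flatten by hand. A minimal sketch; the helper names myFlatten and flatMapViaFlatten are hypothetical:

```scala
// Hypothetical helper: a hand-written flatten for List[List[B]].
// foldRight combines the inner lists starting from the last one, and `:::`
// prepends each inner list onto the result built so far, so every element
// is copied exactly once.
def myFlatten[B](lists: List[List[B]]): List[B] =
  lists.foldRight(List.empty[B])((inner, acc) => inner ::: acc)

// flatMap expressed as "map, then flatten", using the helper above.
def flatMapViaFlatten[A, B](list: List[A])(f: A => List[B]): List[B] =
  myFlatten(list.map(f))

// Each n expands to n copies of itself.
val expanded = flatMapViaFlatten(List(1, 2, 3))(n => List.fill(n)(n))
// expanded == List(1, 2, 2, 3, 3, 3)
```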


Answer 1:


An alternate approach:

def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] =
  list.foldLeft(List[B]())(_ ++ f(_))

I don't know about “better”. (And if we start talking about efficient implementation, that's another can of worms...)
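The placeholder lambda _ ++ f(_) expands to a two-argument function whose first parameter is the accumulator and whose second is the current element. A sketch with the parameters written out; the name flatMapFoldLeft is an assumption, chosen to match the label used in the benchmark in the next answer:

```scala
// Same foldLeft approach, with the placeholder syntax spelled out.
// `acc` is the result built so far; `x` is the current input element.
def flatMapFoldLeft[A, B](list: List[A])(f: A => List[B]): List[B] =
  list.foldLeft(List.empty[B]) { (acc, x) =>
    // `acc ++ f(x)` copies `acc` on every step, so the total work grows
    // quadratically with the number of chunks.
    acc ++ f(x)
  }

val pairs = flatMapFoldLeft(List("a", "b"))(s => List(s, s.toUpperCase))
// pairs == List("a", "A", "b", "B")
```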




Answer 2:


Just to flesh out the answers, you could also define this as a recursive function using pattern matching:

def flatMap[A, B](list: List[A])(f: A => List[B]): List[B] = list match {
  case (x::xs) => f(x) ++ flatMap(xs)(f)
  case _ => Nil
}

Or make it explicitly tail-recursive:

import scala.annotation.tailrec

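def flatMapTailRec[A, B](list: List[A])(f: A => List[B]): List[B] = {
  // `result` accumulates the output in reverse order.
  @tailrec
  def _flatMap(result: List[B])(input: List[A]): List[B] = input match {
    // Prepend f(x) reversed, so the newest elements sit at the front...
    case (x::xs) => _flatMap(f(x).reverse ::: result)(xs)
    // ...then reverse once at the end to restore the original order.
    case _ => result.reverse
  }
  _flatMap(List[B]())(list)
}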
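A tail-recursive accumulator needs some care with ordering. A loop that simply prepends each f(x) onto the accumulator and returns it unchanged produces the chunks last-to-first. A minimal demonstration; the name prependNaive is hypothetical:

```scala
import scala.annotation.tailrec

// Hypothetical naive variant: prepends each f(x) onto the accumulator and
// returns the accumulator as-is. Elements inside each chunk keep their order,
// but the chunks themselves come out last-to-first.
def prependNaive[A, B](list: List[A])(f: A => List[B]): List[B] = {
  @tailrec
  def loop(acc: List[B], rest: List[A]): List[B] = rest match {
    case x :: xs => loop(f(x) ++ acc, xs)
    case Nil     => acc
  }
  loop(Nil, list)
}

val chunk = (x: Int) => List(x, x * 10)
val naiveResult = prependNaive(List(1, 2, 3))(chunk) // List(3, 30, 2, 20, 1, 10)
val expected    = List(1, 2, 3).flatMap(chunk)       // List(1, 10, 2, 20, 3, 30)
```

Reversing each chunk as it is prepended (f(x).reverse ::: acc) and reversing the accumulator once at the end restores the order while keeping the total work linear.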

I did some quick, non-rigorous benchmarking using this sample input (1,001 lists of 1,001 elements each):

val input = (0 to 1000).map(_ => (0 to 1000).toList).toList

In order from fastest to slowest:

  • flatMap(input)(x => x)
    • 0.02069937453 seconds
  • flatMapTailRec(input)(x => x)
    • 0.02335651054 seconds
  • input.flatMap(x => x)
    • 0.0297564358 seconds
  • flatMapFoldLeft(input)(x => x) (the foldLeft version from Answer 1)
    • 12.940458234 seconds
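The answer does not show how these numbers were taken. One simple way to take a single timing like this is to wrap the call in System.nanoTime; the helper name timed is hypothetical, and this does no JIT warm-up or repetition, so results will vary by machine and JVM:

```scala
// Hypothetical timing helper: evaluates `body` once and returns its result
// together with the elapsed wall-clock time in seconds. A single,
// non-rigorous measurement: no warm-up, no repeated runs.
def timed[T](body: => T): (T, Double) = {
  val start = System.nanoTime()
  val result = body
  val elapsed = (System.nanoTime() - start) / 1e9
  (result, elapsed)
}

// Same sample input as above: 1,001 lists of 1,001 elements each.
val input = (0 to 1000).map(_ => (0 to 1000).toList).toList

val (builtIn, builtInSecs) = timed(input.flatMap(x => x))
println(f"input.flatMap: $builtInSecs%.5f s for ${builtIn.size} elements")
```

For more trustworthy numbers, a harness such as JMH handles warm-up and repetition.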

I'm a little surprised that foldLeft comes out so much slower than the others. It would be interesting to see how flatMap is actually defined in the standard library source. I tried looking myself, but there's too much to go through at the moment >_>.
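For comparison, a common way to make flatMap fast is to build the result in a mutable builder that never escapes the function, so callers still see a pure function. A sketch using scala.collection.mutable.ListBuffer; it is not claimed to be the standard library's actual implementation, and the name flatMapBuffer is hypothetical:

```scala
import scala.collection.mutable.ListBuffer

// Appends every element of every f(x) to a local ListBuffer, then converts
// the buffer to an immutable List. Each output element is appended once, so
// the work is linear in the size of the result. Because the buffer is local,
// the function is still referentially transparent from the caller's side.
def flatMapBuffer[A, B](list: List[A])(f: A => List[B]): List[B] = {
  val buf = ListBuffer.empty[B]
  list.foreach(x => buf ++= f(x))
  buf.toList
}

val repeated = flatMapBuffer(List(1, 2, 3))(n => List.fill(n)(n))
// repeated == List(1, 2, 2, 3, 3, 3)
```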

Edit: As Daniel Sobral pointed out in the comments on another answer, these implementations are restricted to List[A]. You could write a more generic implementation that works for any mappable type, but the code would quickly become much more complicated.
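To give a feel for the generic version, here is a minimal sketch built on a hypothetical type class Flattenable, which only requires map and flatten and derives flatMap from them. The name and shape are illustrative assumptions; libraries such as Cats provide a comparable abstraction through their FlatMap and Monad type classes.

```scala
// Hypothetical type class: a container type F that can map its contents
// and flatten one level of nesting.
trait Flattenable[F[_]] {
  def map[A, B](fa: F[A])(f: A => B): F[B]
  def flatten[A](ffa: F[F[A]]): F[A]
}

// flatMap derived generically as "map, then flatten", for any F with an instance.
def genericFlatMap[F[_], A, B](fa: F[A])(f: A => F[B])(implicit F: Flattenable[F]): F[B] =
  F.flatten(F.map(fa)(f))

// Instance for List, delegating to the built-in map and flatten.
implicit val listFlattenable: Flattenable[List] = new Flattenable[List] {
  def map[A, B](fa: List[A])(f: A => B): List[B] = fa.map(f)
  def flatten[A](ffa: List[List[A]]): List[A] = ffa.flatten
}

// Instance for Option: Some(Some(a)) flattens to Some(a), anything else to None.
implicit val optionFlattenable: Flattenable[Option] = new Flattenable[Option] {
  def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa.map(f)
  def flatten[A](ffa: Option[Option[A]]): Option[A] = ffa.flatten
}

val doubledList = genericFlatMap(List(1, 2, 3))(n => List(n, n))
// doubledList == List(1, 1, 2, 2, 3, 3)
```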




Answer 3:


Your solution is already pretty functional, and it emphasizes where the name flatMap comes from. Note that x => f(x) is just f, so it boils down to:

list.map(f).flatten

Using foldLeft is a terrible idea, due to the quadratic behavior induced by concatenation: ++ copies its left operand, and foldLeft builds (((List() ++ a) ++ b) ++ c) ++ d, which copies a 3 times, b 2 times, c once, and so on.
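To put a number on "quadratic", here is a rough cost model, assuming (as above) that left ++ right on Lists copies every element of left and shares right. It counts the element copies the foldLeft approach performs; the function name copiesFoldLeft is hypothetical and the figures are a model, not a measurement:

```scala
// Cost model (assumption): `left ++ right` copies `left` and shares `right`.
// Counts element copies for foldLeft-style concatenation of chunks with
// the given sizes: (((List() ++ c1) ++ c2) ++ c3) ...
def copiesFoldLeft(sizes: List[Long]): Long = {
  val (_, copies) = sizes.foldLeft((0L, 0L)) { case ((accLen, copiesSoFar), size) =>
    // The accumulator built so far (accLen elements) is copied on this step.
    (accLen + size, copiesSoFar + accLen)
  }
  copies
}

// Answer 2's benchmark input: 1,001 chunks of 1,001 elements each.
val sizes = List.fill(1001)(1001L)
val foldLeftCopies = copiesFoldLeft(sizes) // 501000500
val outputSize     = sizes.sum             // 1002001
```

Under this model the foldLeft version performs about 500 times as many copies as there are elements in the result, which is consistent with it being the outlier in Answer 2's timings.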

foldRight would be fine: it builds a ++ (b ++ (c ++ (d ++ List()))), so each chunk is copied only once.
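A minimal sketch of the foldRight version; the name flatMapFoldRight is hypothetical. It uses ::: (prepend a whole list), which copies only f(x) and shares the accumulator:

```scala
// flatMap via foldRight: chunks are combined starting from the last element,
// and each f(x) is prepended with `:::`, which copies f(x) and shares `acc`.
def flatMapFoldRight[A, B](list: List[A])(f: A => List[B]): List[B] =
  list.foldRight(List.empty[B])((x, acc) => f(x) ::: acc)

val letters = flatMapFoldRight(List("ab", "cd"))(_.toList)
// letters == List('a', 'b', 'c', 'd')
```

One consequence of folding from the right is that f is called on the elements last-to-first, which only matters if f has side effects.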



Source: https://stackoverflow.com/questions/20436614/implementing-listflatmap
