Simplest way to get the top n elements of a Scala Iterable

佛祖请我去吃肉 2020-11-29 02:48

Is there a simple and efficient solution to determine the top n elements of a Scala Iterable? I mean something like

iter.toList.sortBy(_.myAttr).take(2)

  • 2020-11-29 03:14

    You don't need to sort the entire collection in order to determine the top N elements. However, I don't believe this functionality is provided by the standard library, so you would have to roll your own, possibly using the pimp-my-library (enrich-my-library) pattern.

    For example, you can get the nth element of a collection as follows:

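      import scala.collection.TraversableLike       // Scala 2.8 to 2.12 collections API
      import scala.collection.generic.CanBuildFrom  // needed by topN below

      // `Repr <% ...` is a view bound (deprecated in later Scala versions)
      class Pimp[A, Repr <% TraversableLike[A, Repr]](self : Repr) {
    
        // Quickselect: returns the element at 0-based index n in ord's sorted order
        def nth(n : Int)(implicit ord : Ordering[A]) : A = {
          val trav : TraversableLike[A, Repr] = self
          var ltp : List[A] = Nil  // less than the pivot
          var etp : List[A] = Nil  // equal to the pivot
          var mtp : List[A] = Nil  // more than the pivot
          trav.headOption match {
            case None      => sys.error("Cannot get " + n + " element of empty collection")
            case Some(piv) =>
              trav.foreach { a =>
                val cf = ord.compare(piv, a)
                if (cf == 0) etp ::= a
                else if (cf > 0) ltp ::= a
                else mtp ::= a
              }
              // Recurse only into the part that contains index n
              if (n < ltp.length)
                new Pimp[A, List[A]](ltp.reverse).nth(n)(ord)
              else if (n < (ltp.length + etp.length))
                piv
              else
                new Pimp[A, List[A]](mtp.reverse).nth(n - ltp.length - etp.length)(ord)
          }
        }
      }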
    

    (This is not very functional; sorry)

    It's then trivial to get the top n elements:

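    // Add inside class Pimp: the n smallest elements by ord (requires n < size, because of nth)
    def topN(n : Int)(implicit ord : Ordering[A], bf : CanBuildFrom[Repr, A, Repr]) = {
      val b = bf()
      val elem = new Pimp[A, Repr](self).nth(n)(ord)  // the (n+1)-th smallest value
      import scala.util.control.Breaks._
      breakable {
        var soFar = 0
        // Take every element strictly smaller than elem: there are at most n of them...
        self.foreach { tt =>
          if (ord.compare(tt, elem) < 0) {
             b += tt
             soFar += 1
          }
        }
        assert (soFar <= n)
        // ...then fill up with elements equal to elem until there are exactly n
        if (soFar < n) {
          self.foreach { tt =>
            if (ord.compare(tt, elem) == 0) {
              b += tt
              soFar += 1
            }
            if (soFar == n) break
          }
        }
    
      }
      b.result()
    }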
    

    Unfortunately I'm having trouble getting this pimp to be discovered via this implicit:

    implicit def t2n[A, Repr <% TraversableLike[A, Repr]](t : Repr) : Pimp[A, Repr] 
      = new Pimp[A, Repr](t)
    

    I get this:

    scala> List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
    <console>:9: error: could not find implicit value for evidence parameter of type (List[Int]) => scala.collection.TraversableLike[A,List[Int]]
       List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
           ^
    

    However, the code actually works OK:

    scala> new Pimp(List(4, 3, 6, 7, 1, 2, 8, 5)).topN(4)
    res3: List[Int] = List(3, 1, 2, 4)
    

    And

    scala> new Pimp("ioanusdhpisjdmpsdsvfgewqw").topN(6)
    res2: java.lang.String = affffdfe
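
    One way to sidestep the inference problem is to enrich Iterable[A] directly with an implicit class. The element type A then comes straight from the receiver's type, so there is no separate Repr for the compiler to infer. The sketch below assumes Scala 2.13 or Scala 3. The names TopNOps and smallestN are hypothetical, and it uses a bounded PriorityQueue instead of the quickselect above.

```scala
import scala.collection.mutable.PriorityQueue

// Hypothetical enrichment: A is taken from the receiver's static type Iterable[A],
// so there is no separate Repr type parameter for the compiler to infer.
implicit class TopNOps[A](self: Iterable[A]) {
  // Returns the n smallest elements according to ord, in ascending order.
  def smallestN(n: Int)(implicit ord: Ordering[A]): List[A] = {
    // Max-heap of at most n elements; heap.head is the largest element kept so far.
    val heap = PriorityQueue.empty[A](ord)
    self.foreach { a =>
      if (heap.size < n) heap.enqueue(a)
      else if (n > 0 && ord.lt(a, heap.head)) {
        heap.dequeue() // drop the current largest kept element
        heap.enqueue(a)
      }
    }
    // dequeueAll yields the largest first, so reverse it to get ascending order.
    heap.dequeueAll.reverse.toList
  }
}
```

    For example, List(4, 3, 6, 7, 1, 2, 8, 5).smallestN(4) returns List(1, 2, 3, 4). Unlike topN above, this version returns a List rather than the original collection type.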
    
  • 2020-11-29 03:15

    If the goal is to avoid sorting the whole list, you could do something like this (of course, it could be optimized a bit so that the list isn't rebuilt when the new number clearly doesn't belong in it):

    List(1,6,3,7,3,2).foldLeft(List[Int]()){(l, n) => (n :: l).sorted.take(2)}
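
    The optimization mentioned above can be sketched as follows. The fold skips the rebuild whenever the accumulator is already full and the new element is not smaller than its current largest member. This relies on acc staying sorted ascending, which the sorted call guarantees. The name smallestK is hypothetical.

```scala
// Hypothetical helper: the k smallest elements of xs, in ascending order.
def smallestK[T](k: Int, xs: List[T])(implicit ord: Ordering[T]): List[T] =
  xs.foldLeft(List.empty[T]) { (acc, x) =>
    // Invariant: acc is sorted ascending and holds at most k elements.
    if (acc.size >= k && (k <= 0 || ord.gteq(x, acc.last))) acc // x cannot enter: skip the rebuild
    else (x :: acc).sorted.take(k)
  }
```

    smallestK(2, List(1, 6, 3, 7, 3, 2)) returns List(1, 2), the same result as the one-liner. Each accepted element still costs a sort of at most k + 1 elements, so this approach is best suited to small k.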
    
  • 2020-11-29 03:17

    Here's another solution that is simple and has pretty good performance.

    def pickTopN[T](k: Int, iterable: Iterable[T])(implicit ord: Ordering[T]): Seq[T] = {
      // PriorityQueue is a max-heap: each dequeue() removes the largest remaining element
      val q = collection.mutable.PriorityQueue[T](iterable.toSeq: _*)
      val end = math.min(k, q.size)
      (1 to end).map(_ => q.dequeue())
    }
    

    The running time is O(n + k log n), where k <= n, so it is linear for small k and O(n log n) at worst.

    The solution can also be optimized to use O(k) memory, at the cost of O(n log k) time. The idea is to use a min-heap that tracks only the top k items at all times. Here's the solution:

    def pickTopN[A, B](n: Int, iterable: Iterable[A], f: A => B)(implicit ord: Ordering[B]): Seq[A] = {
      val seq = iterable.toSeq
      val q = collection.mutable.PriorityQueue[A](seq.take(n):_*)(ord.on(f).reverse) // initialize with first n
    
      // invariant: q holds the top n elements scanned so far; q.head is the smallest of them
      seq.drop(n).foreach(v => {
        q += v
        q.dequeue()
      })
    
      q.dequeueAll.reverse
    }
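
    Note that iterable.toSeq may copy the whole input into a new collection (for example when iterable is a Set), so the O(k) memory bound only covers the heap itself. The sketch below walks the Iterable once with foreach and never holds more than n elements. It assumes Scala 2.13 or Scala 3, and pickTopNStreaming is a hypothetical name.

```scala
import scala.collection.mutable.PriorityQueue

// Hypothetical variant: top n elements by f, largest first, without calling toSeq.
def pickTopNStreaming[A, B](n: Int, iterable: Iterable[A], f: A => B)(implicit ord: Ordering[B]): Seq[A] = {
  // Min-heap on f: q.head is the smallest of the (at most n) elements kept so far.
  val q = PriorityQueue.empty[A](ord.on(f).reverse)
  iterable.foreach { v =>
    if (q.size < n) q.enqueue(v)
    else if (n > 0 && ord.gt(f(v), f(q.head))) {
      q.dequeue() // evict the smallest kept element
      q.enqueue(v)
    }
  }
  // With the reversed ordering, dequeueAll yields ascending f, so reverse to get largest first.
  q.dequeueAll.reverse
}
```

    For example, pickTopNStreaming(3, List(4, 3, 6, 7, 1, 2, 9, 5), identity[Int]) yields 9, 7, 6. The per-element cost is O(log n), and an element that cannot enter the heap costs only a comparison.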
    
  • 2020-11-29 03:19

    For small values of n and large lists, getting the top n elements can be implemented by picking out the max element n times:

    def top[T](n:Int, iter:Iterable[T])(implicit ord: Ordering[T]): Iterable[T] = {
      def partitionMax(acc: Iterable[T], it: Iterable[T]): Iterable[T]  = {
        val max = it.max(ord)
        val (nextElems, rest) = it.partition(ord.gteq(_, max))
        val maxElems = acc ++ nextElems
        if (maxElems.size >= n || rest.isEmpty) maxElems.take(n)
        else partitionMax(maxElems, rest)
      }
      if (iter.isEmpty) iter.take(0)
      else partitionMax(iter.take(0), iter)
    }
    

    This does not sort the entire list and takes an Ordering. I believe every method I call in partitionMax is O(list size), and I only expect to call it at most n times, so for small n the overall running time will be proportional to the size of the iterable.

    scala> top(5, List.range(1,1000000))
    res13: Iterable[Int] = List(999999, 999998, 999997, 999996, 999995)
    
    scala> top(5, List.range(1,1000000))(Ordering[Int].on(- _))
    res14: Iterable[Int] = List(1, 2, 3, 4, 5)
    

    You could also add a branch for when n gets close to the size of the iterable, and switch to a full sort such as iter.toList.sorted(ord.reverse).take(n) there.

    It does not return the type of collection provided, but you can look at How do I apply the enrich-my-library pattern to Scala collections? if this is a requirement.
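
    If the result should keep the input's collection type, Scala 2.13's BuildFrom type class (the successor of CanBuildFrom) can rebuild it. The sketch below is hypothetical (topNSameType is not a library method). To keep the example short it simply sorts, and any of the selection strategies on this page could produce the elements instead.

```scala
import scala.collection.BuildFrom

// Hypothetical helper: the n largest elements, returned in the same collection type as xs.
// BuildFrom[CC[A], A, CC[A]] tells the compiler how to build a CC[A] from A elements.
def topNSameType[CC[X] <: Iterable[X], A](xs: CC[A], n: Int)(
    implicit ord: Ordering[A], bf: BuildFrom[CC[A], A, CC[A]]): CC[A] =
  // Sorting is only a placeholder for a cheaper selection step.
  bf.fromSpecific(xs)(xs.toList.sorted(ord.reverse).take(n))
```

    With this helper, topNSameType(Vector(1, 5, 3), 2) is statically a Vector[Int] and topNSameType(List(4, 3, 6, 7, 1, 2, 8, 5), 3) is a List[Int].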

  • 2020-11-29 03:19

    An optimized solution using a PriorityQueue, with time complexity O(n log k). The approach given in the update re-sorts the sofar list on every step, which is not needed; below, a PriorityQueue is used instead.

    import scala.collection.mutable.PriorityQueue

    implicit class IterExt[A](iter: Iterable[A]) {
        // Returns the n elements with the smallest f(x), in ascending order of f
        def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]): List[A] = {
            // sofar is a max-heap on f, so sofar.head is the worst of the n elements kept
            def updateSofar(sofar: PriorityQueue[A], el: A): PriorityQueue[A] = {
                if (sofar.nonEmpty && ord.compare(f(el), f(sofar.head)) < 0) {
                    sofar.dequeue()
                    sofar.enqueue(el)
                }
                sofar
            }
    
            val (sofar, rest) = iter.splitAt(n)
            rest.foldLeft(PriorityQueue(sofar.toSeq: _*)(Ordering.by(f)))(updateSofar)
                .dequeueAll.toList.reverse
        }
    }
    
    case class A(s: String, i: Int)
    val li = List (4, 3, 6, 7, 1, 2, 9, 5).map(i => A(i.toString(), i))
    println(li.top(3, -_.i))
    
  • 2020-11-29 03:21

    Here is an asymptotically O(n) solution.

    def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
        require( n < data.size)
    
        def partition_inner(shuffledData: List[T], pivot: T): List[T] = 
          shuffledData.partition( e => ord.compare(e, pivot) > 0 ) match {
              case (left, right) if left.size == n => left
              case (left, x :: rest) if left.size < n => 
                partition_inner(util.Random.shuffle(data), x)
              case (left @ y :: rest, right) if left.size > n => 
                partition_inner(util.Random.shuffle(data), y)
          }
    
        val shuffled = util.Random.shuffle(data)
        partition_inner(shuffled, shuffled.head)
    }
    
    scala> top(List.range(1,10000000), 5)
    

    Because every round reshuffles and repartitions the whole input list, this solution can take longer than some of the non-linear solutions above and can cause java.lang.OutOfMemoryError: GC overhead limit exceeded. But it is slightly more readable IMHO, and functional in style. Just for a job interview ;).

    More importantly, this solution can easily be parallelized.

    import scala.annotation.tailrec
    // On Scala 2.13+, .par requires the scala-parallel-collections module and
    // import scala.collection.parallel.CollectionConverters._

    def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
        require( n < data.size)
    
        @tailrec
        def partition_inner(shuffledData: List[T], pivot: T): List[T] = 
          shuffledData.par.partition( e => ord.compare(e, pivot) > 0 ) match {
              case (left, right) if left.size == n => left.toList
              case (left, right) if left.size < n => 
                partition_inner(util.Random.shuffle(data), right.head)
              case (left, right) if left.size > n => 
                partition_inner(util.Random.shuffle(data), left.head)
          }
    
        val shuffled = util.Random.shuffle(data)
        partition_inner(shuffled, shuffled.head)
    }
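
    For comparison, here is a sketch of textbook quickselect in the same functional style. Each round recurses only into the side of the partition that still matters, so the expected total work is linear in the input size. The three-way split (greater than, equal to, and smaller than the pivot) lets it stop at exactly n elements even when values repeat. topQuickselect is a hypothetical name, and the result is unordered.

```scala
import scala.annotation.tailrec
import scala.util.Random

// Hypothetical helper: the n largest elements of data (in no particular order), via quickselect.
def topQuickselect[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
  @tailrec
  def loop(candidates: List[T], need: Int, acc: List[T]): List[T] =
    if (need <= 0) acc
    else if (candidates.size <= need) candidates ::: acc // everything left is needed
    else {
      val pivot = candidates(Random.nextInt(candidates.size))
      val greater = candidates.filter(ord.gt(_, pivot))
      val equal = candidates.filter(ord.equiv(_, pivot))
      if (greater.size >= need) loop(greater, need, acc) // answer lies entirely above the pivot
      else if (greater.size + equal.size >= need)
        greater ::: equal.take(need - greater.size) ::: acc
      else // keep all of greater and equal, then continue below the pivot
        loop(candidates.filter(ord.lt(_, pivot)), need - greater.size - equal.size, greater ::: equal ::: acc)
    }
  loop(data, n, Nil)
}
```

    Every recursive call works on a strictly smaller list, since the pivot itself is never passed on, so the loop always terminates.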
    