Is there a simple and efficient solution to determine the top n elements of a Scala Iterable? I mean something like
iter.toList.sortBy(_.myAttr).take(2)
For small values of n and large lists, getting the top n elements can be implemented by picking out the max element n times:
def top[T](n: Int, iter: Iterable[T])(implicit ord: Ordering[T]): Iterable[T] = {
  // Move every element equal to the current max from `it` into `acc`,
  // and repeat until at least n elements have been collected.
  def partitionMax(acc: Iterable[T], it: Iterable[T]): Iterable[T] = {
    val max = it.max(ord)
    val (nextElems, rest) = it.partition(ord.gteq(_, max))
    val maxElems = acc ++ nextElems
    if (maxElems.size >= n || rest.isEmpty) maxElems.take(n)
    else partitionMax(maxElems, rest)
  }
  if (iter.isEmpty) iter.take(0)
  else partitionMax(iter.take(0), iter)
}
This does not sort the entire list and takes an Ordering. I believe every method I call in partitionMax is O(list size), and I only expect to call it n times at most, so the overall cost is O(n × list size). For small n, that is proportional to the size of the iterable.
scala> top(5, List.range(1,1000000))
res13: Iterable[Int] = List(999999, 999998, 999997, 999996, 999995)
scala> top(5, List.range(1,1000000))(Ordering[Int].on(- _))
res14: Iterable[Int] = List(1, 2, 3, 4, 5)
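For comparison, a different technique (not the repeated-max approach above) is to keep a bounded priority queue of the n largest elements seen so far, which makes a single pass at roughly O(list size × log n). The following is only a minimal sketch under that assumption; the names TopHeap and topWithHeap are hypothetical, and it returns a List rather than the input's collection type.

```scala
import scala.collection.mutable

object TopHeap {
  // Hypothetical helper: keeps at most n elements in a min-heap (by `ord`),
  // so the smallest of the current top n is always at the head.
  def topWithHeap[T](n: Int, iter: Iterable[T])(implicit ord: Ordering[T]): List[T] = {
    if (n <= 0) Nil
    else {
      // PriorityQueue exposes the largest element under its ordering,
      // so passing the reversed ordering puts the smallest kept element at the head.
      val heap = mutable.PriorityQueue.empty[T](ord.reverse)
      iter.foreach { x =>
        if (heap.size < n) heap.enqueue(x)
        else if (ord.gt(x, heap.head)) {
          // x beats the smallest kept element, so it replaces it.
          heap.dequeue()
          heap.enqueue(x)
        }
      }
      // Only n elements remain, so this final sort is cheap.
      heap.toList.sorted(ord.reverse)
    }
  }
}
```

Called as TopHeap.topWithHeap(5, List.range(1, 1000000)), it should give the same five elements as the first top call above. Whether it is actually faster than repeated max for very small n would need measuring.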
You could also add a branch for when n gets close to the size of the iterable, and switch to iter.toList.sortBy(_.myAttr).take(n).
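A rough sketch of such a branch might look like the following. The wrapper name topAdaptive and the cutoff of half the collection size are arbitrary assumptions for illustration, not measured values, and top is repeated so that the snippet compiles on its own. It sorts with ord.reverse so that both branches return the largest elements first.

```scala
object TopAdaptive {
  // The repeated-max selection from above, copied so this snippet is self-contained.
  def top[T](n: Int, iter: Iterable[T])(implicit ord: Ordering[T]): Iterable[T] = {
    def partitionMax(acc: Iterable[T], it: Iterable[T]): Iterable[T] = {
      val max = it.max(ord)
      val (nextElems, rest) = it.partition(ord.gteq(_, max))
      val maxElems = acc ++ nextElems
      if (maxElems.size >= n || rest.isEmpty) maxElems.take(n)
      else partitionMax(maxElems, rest)
    }
    if (iter.isEmpty) iter.take(0)
    else partitionMax(iter.take(0), iter)
  }

  // Hypothetical wrapper: falls back to a full sort once n is a large
  // fraction of the input. The size / 2 cutoff is an untuned assumption.
  def topAdaptive[T](n: Int, iter: Iterable[T])(implicit ord: Ordering[T]): Iterable[T] = {
    if (n >= iter.size / 2) iter.toList.sorted(ord.reverse).take(n) // descending, to match top
    else top(n, iter)
  }
}
```

Note that iter.size may itself traverse the collection, which is one more O(list size) pass on top of the selection.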
Note that top does not return the same type of collection it was given. If that is a requirement, you can look at "How do I apply the enrich-my-library pattern to Scala collections?".