Spark column wise word count

Daniel Darabos

Let me simplify your question a bit. (A lot actually.) We have an RDD[(Int, String)] and we want to find the top 10 most common Strings for each Int (which are all in the 0–100 range).

Instead of sorting, as in your example, it is more efficient to use the Spark built-in RDD.top(n) method. Its run-time is linear in the size of the data, and requires moving much less data around than a sort.
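
For reference (this example is mine, not part of the original answer): top(n) takes an implicit Ordering, so for word counts you order by the count. A minimal single-column sketch, assuming a SparkContext named sc as in spark-shell:

// Count the words of one column and take the 10 most frequent without a full sort;
// top(n) only keeps a bounded heap of candidates per partition.
val words = sc.parallelize(Seq("a", "b", "a", "c", "a", "b"))
val top10: Array[(String, Long)] =
  words.map(_ -> 1L).reduceByKey(_ + _).top(10)(Ordering.by(_._2))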

Consider the implementation of top in RDD.scala. You want to do the same, but with one priority queue (heap) per Int key. The code becomes fairly complex:

import org.apache.spark.util.BoundedPriorityQueue // Pretend it's not private.

def top(n: Int, rdd: RDD[(Int, String)]): Map[Int, Iterable[String]] = {
  // A heap that only keeps the top N values, so it has bounded size.
  type Heap = BoundedPriorityQueue[(Long, String)]
  // Get the word counts.
  val counts: RDD[[(Int, String), Long)] =
    rdd.map(_ -> 1L).reduceByKey(_ + _)
  // In each partition create a column -> heap map.
  val perPartition: RDD[Map[Int, Heap]] =
    counts.mapPartitions { items =>
      val heaps =
        collection.mutable.Map[Int, Heap].withDefault(i => new Heap(n))
      for (((k, v), count) <- items) {
        heaps(k) += count -> v
      }
      Iterator.single(heaps)
    }
  // Merge the per-partition heap maps into one.
  val merged: Map[Int, Heap] =
    perPartition.reduce { (heaps1, heaps2) =>
      val heaps =
        collection.mutable.Map[Int, Heap].withDefault(i => new Heap(n))
      for ((k, heap) <- heaps1.toSeq ++ heaps2.toSeq) {
        for (cv <- heap) {
          heaps(k) += cv
        }
      }
      heaps
    }
  // Discard counts, return just the top strings.
  merged.mapValues(_.map { case(count, value) => value })
}

This is efficient, but it is made painful by the need to work with multiple columns at the same time. It would be much easier to have one RDD per column and just call rdd.top(10) on each.

Unfortunately the naive way to split up the RDD into N smaller RDDs does N passes:

def split(together: RDD[(Int, String)], columns: Int): Seq[RDD[String]] = {
  together.cache // We will make N passes over this RDD.
  (0 until columns).map {
    i => together.filter { case (key, value) => key == i }.values
  }
}
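
With that split in hand, the per-column job is just a word count plus top per RDD. A minimal sketch of how the two pieces could be combined (topPerColumn is a hypothetical name, not from the original answer):

import org.apache.spark.rdd.RDD

// One filter pass per column over the cached RDD, then a bounded-heap top(10) on each.
def topPerColumn(together: RDD[(Int, String)], columns: Int): Map[Int, Seq[String]] = {
  split(together, columns).zipWithIndex.map { case (column, i) =>
    i -> column.map(_ -> 1L).reduceByKey(_ + _)
      .top(10)(Ordering.by(_._2))
      .map(_._1).toSeq // Drop the counts, keep only the words.
  }.toMap
}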

A more efficient solution could be to write out the data into separate files by key, then load it back into separate RDDs. This is discussed in Write to multiple outputs by key Spark - one Spark job.
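
A rough sketch of that write-by-key idea, along the lines of the linked question (the class name and output path here are placeholders of my own; MultipleTextOutputFormat comes from the old Hadoop mapred API):

import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapred.lib.MultipleTextOutputFormat

// Route each record into a subdirectory named after its key, so one job writes all columns.
class KeyBasedOutput extends MultipleTextOutputFormat[Any, Any] {
  override def generateFileNameForKeyValue(key: Any, value: Any, name: String): String =
    key.toString + "/" + name // name is the usual part-NNNNN file name.
  override def generateActualKey(key: Any, value: Any): Any =
    NullWritable.get() // Keep only the value in the file contents.
}

// rdd.map { case (column, word) => (column.toString, word) }
//   .saveAsHadoopFile("/tmp/by-column", classOf[String], classOf[String], classOf[KeyBasedOutput])
// Each column can then be read back with sc.textFile("/tmp/by-column/" + i) and topped on its own.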

Thanks to @Daniel Darabos for his answer, but there are a few mistakes:

  1. Mixed use of Map and collection.mutable.Map: the per-partition and merged values are mutable maps, but they are declared with the immutable Map type (and the function returns an immutable Map).

  2. withDefault((i: Int) => new Heap(n)) does not insert a new Heap into the map when you run heaps(k) += count -> v; the default heap is created, mutated and then thrown away, so the update is lost (see the small illustration after this list).

  3. Mismatched or missing parentheses and brackets, e.g. RDD[[(Int, String), Long)] should be RDD[((Int, String), Long)], and collection.mutable.Map[Int, Heap] needs an argument list: collection.mutable.Map[Int, Heap]().
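
A small REPL-style illustration of point 2, using a plain ArrayBuffer instead of the (private) BoundedPriorityQueue: withDefault only affects reads, so the mutated default value is never stored back, while getOrElseUpdate (or the explicit contains check used in the modified code below) actually inserts it:

import scala.collection.mutable

val m = mutable.Map[Int, mutable.ArrayBuffer[String]]().withDefault(_ => mutable.ArrayBuffer[String]())
m(1) += "a"            // Mutates a fresh default buffer that is never put into the map.
println(m.contains(1)) // false: the update is lost.

val m2 = mutable.Map[Int, mutable.ArrayBuffer[String]]()
m2.getOrElseUpdate(1, mutable.ArrayBuffer[String]()) += "a" // Inserts the buffer, then mutates it.
println(m2(1))         // ArrayBuffer(a)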

Here is the modified code:

// import org.apache.spark.util.BoundedPriorityQueue // It is private to Spark: copy the class into your own package and import it from there.
import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}


object BoundedPriorityQueueTest {

  //  https://stackoverflow.com/questions/28166190/spark-column-wise-word-count
  def top(n: Int, rdd: RDD[(Int, String)]): Map[Int, Iterable[String]] = {
    // A heap that only keeps the top N values, so it has bounded size.
    type Heap = BoundedPriorityQueue[(Long, String)]
    // Get the word counts.
    val counts: RDD[((Int, String), Long)] =
    rdd.map(_ -> 1L).reduceByKey(_ + _)
    // In each partition create a column -> heap map.
    val perPartition: RDD[collection.mutable.Map[Int, Heap]] =
    counts.mapPartitions { items =>
      val heaps =
        collection.mutable.Map[Int, Heap]() // .withDefault((i: Int) => new Heap(n))
      for (((k, v), count) <- items) {
        println("\n---")
        println("before add " + ((k, v), count) + ", the map is: ")
        println(heaps)
        if (!heaps.contains(k)) {
          println("not contains key " + k)
          heaps(k) = new Heap(n)
          println(heaps)
        }
        heaps(k) += count -> v
        println("after add " + ((k, v), count) + ", the map is: ")
        println(heaps)

      }
      println(heaps)
      Iterator.single(heaps)
    }
    // Merge the per-partition heap maps into one.
    val merged: collection.mutable.Map[Int, Heap] =
    perPartition.reduce { (heaps1, heaps2) =>
      val heaps =
        collection.mutable.Map[Int, Heap]() //.withDefault((i: Int) => new Heap(n))
      println(heaps)
      for ((k, heap) <- heaps1.toSeq ++ heaps2.toSeq) {
        for (cv <- heap) {
          heaps(k) += cv
        }
      }
      heaps
    }
    // Discard counts, return just the top strings.
    merged.mapValues(_.map { case (count, value) => value }).toMap
  }

  def main(args: Array[String]): Unit = {
    Logger.getRootLogger().setLevel(Level.FATAL) //http://stackoverflow.com/questions/27781187/how-to-stop-messages-displaying-on-spark-console
    val conf = new SparkConf().setAppName("word count").setMaster("local[1]")
    val sc = new SparkContext(conf)
    sc.setLogLevel("WARN") //http://stackoverflow.com/questions/27781187/how-to-stop-messages-displaying-on-spark-console


    val words = sc.parallelize(List((1, "s11"), (1, "s11"), (1, "s12"), (1, "s13"), (2, "s21"), (2, "s22"), (2, "s22"), (2, "s23")))
    println("# words:" + words.count())

    val result = top(1, words)

    println("\n--result:")
    println(result)
    sc.stop()

    print("DONE")
  }

}