Counting Neighboring Words over Large Data Sets


This problem is similar to some of the search problems on LeetCode.

The problem we want to solve: for each word, count the words that appear within a window of positions before and after it. Given the words w1, w2, w3, w4, w5, w6:

(figure omitted: the neighbor pairs produced for w1 through w6 within the window)

The final output has the form (word, neighbor, frequency).
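
As a quick illustration, here is a minimal sketch in plain Scala (no Spark; the object name and the window size of 2 are illustrative assumptions) that builds the (word, neighbor) pairs for a tiny sentence and prints each pair's relative frequency, that is, count(word, neighbor) divided by the total number of neighbors of word:

object NeighborPairsSketch {
  def main(args: Array[String]): Unit = {
    val window = 2                                   // assumed neighbor window
    val tokens = "w1 w2 w3 w4 w5 w6".split("\\s+")

    // every (word, neighbor) pair within the window
    val pairs = for {
      i <- tokens.indices
      start = math.max(0, i - window)
      end = math.min(tokens.length - 1, i + window)
      j <- start to end if j != i
    } yield (tokens(i), tokens(j))

    val pairCounts = pairs.groupBy(identity).mapValues(_.size)  // count(word, neighbor)
    val totals = pairs.groupBy(_._1).mapValues(_.size)          // total neighbors of word

    pairCounts.toSeq.sortBy(_._1).foreach { case ((w, n), c) =>
      println(s"$w\t$n\t${c.toDouble / totals(w)}")             // (word, neighbor, frequency)
    }
  }
}

For example, with a window of 2 the word w1 only has w2 and w3 as neighbors, so (w1, w2) and (w1, w3) each come out with a relative frequency of 0.5.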

We implement it in five ways:

  • MapReduce
  • Spark
  • Spark SQL
  • Scala
  • Spark SQL in Scala

MapReduce


//map function
@Override
protected void map(LongWritable key, Text value, Context context)
        throws IOException, InterruptedException {

    String[] tokens = StringUtils.split(value.toString(), " ");
    //String[] tokens = StringUtils.split(value.toString(), "\\s+");
    if ((tokens == null) || (tokens.length < 2)) {
        return;
    }
    // generate the neighbor pairs for each word within the window
    for (int i = 0; i < tokens.length; i++) {
        tokens[i] = tokens[i].replaceAll("\\W+", "");

        if (tokens[i].equals("")) {
            continue;
        }

        pair.setWord(tokens[i]);

        int start = (i - neighborWindow < 0) ? 0 : i - neighborWindow;
        int end = (i + neighborWindow >= tokens.length) ? tokens.length - 1 : i + neighborWindow;
        for (int j = start; j <= end; j++) {
            if (j == i) {
                continue;
            }
            pair.setNeighbor(tokens[j].replaceAll("\\W", ""));
            context.write(pair, ONE);
        }
        // also emit (word, "*") with the number of neighbors of this occurrence,
        // so the reducer can accumulate the word's total count
        pair.setNeighbor("*");
        totalCount.set(end - start);
        context.write(pair, totalCount);
    }
}

//reduce function
// Note: this relies on the keys being sorted and partitioned so that the
// (word, "*") key reaches the reducer before all (word, neighbor) keys of
// the same word (the order-inversion pattern).
@Override
protected void reduce(PairOfWords key, Iterable<IntWritable> values, Context context)
        throws IOException, InterruptedException {
    // a "*" neighbor marks the word itself; its count becomes totalCount
    if (key.getNeighbor().equals("*")) {
        if (key.getWord().equals(currentWord)) {
            totalCount += getTotalCount(values);
        } else {
            currentWord = key.getWord();
            totalCount = getTotalCount(values);
        }
    } else {
        // otherwise it is a single (word, neighbor) pair: sum its counts with
        // getTotalCount and divide by totalCount to get the relative frequency
        int count = getTotalCount(values);
        relativeCount.set((double) count / totalCount);
        context.write(key, relativeCount);
    }

}

Spark

public static void main(String[] args) {
        if (args.length < 3) {
            System.out.println("Usage: RelativeFrequencyJava <neighbor-window> <input-dir> <output-dir>");
            System.exit(1);
        }

        SparkConf sparkConf = new SparkConf().setAppName("RelativeFrequency");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        int neighborWindow = Integer.parseInt(args[0]);
        String input = args[1];
        String output = args[2];

        final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);

        JavaRDD<String> rawData = sc.textFile(input);

        /*
         * Transform the input to the format: (word, (neighbour, 1))
         */
        JavaPairRDD<String, Tuple2<String, Integer>> pairs = rawData.flatMapToPair(
                new PairFlatMapFunction<String, String, Tuple2<String, Integer>>() {
            private static final long serialVersionUID = -6098905144106374491L;

            @Override
            public java.util.Iterator<scala.Tuple2<String, scala.Tuple2<String, Integer>>> call(String line) throws Exception {
                List<Tuple2<String, Tuple2<String, Integer>>> list = new ArrayList<Tuple2<String, Tuple2<String, Integer>>>();
                String[] tokens = line.split("\\s");
                for (int i = 0; i < tokens.length; i++) {
                    int start = (i - brodcastWindow.value() < 0) ? 0 : i - brodcastWindow.value();
                    int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1 : i + brodcastWindow.value();
                    for (int j = start; j <= end; j++) {
                        if (j != i) {
                            list.add(new Tuple2<String, Tuple2<String, Integer>>(tokens[i], new Tuple2<String, Integer>(tokens[j], 1)));
                        } else {
                            // do nothing
                            continue;
                        }
                    }
                }
                return list.iterator();
            }
        }
        );

        // (word, sum(word))
        // PairFunction<T, K, V>: T => Tuple2<K, V>
        JavaPairRDD<String, Integer> totalByKey = pairs.mapToPair(

                new PairFunction<Tuple2<String, Tuple2<String, Integer>>, String, Integer>() {
            private static final long serialVersionUID = -213550053743494205L;

            @Override
            public Tuple2<String, Integer> call(Tuple2<String, Tuple2<String, Integer>> tuple) throws Exception {
                return new Tuple2<String, Integer>(tuple._1, tuple._2._2);
            }
        }).reduceByKey(
                        new Function2<Integer, Integer, Integer>() {
                    private static final long serialVersionUID = -2380022035302195793L;

                    @Override
                    public Integer call(Integer v1, Integer v2) throws Exception {
                        return (v1 + v2);
                    }
                });

        JavaPairRDD<String, Iterable<Tuple2<String, Integer>>> grouped = pairs.groupByKey();

        // (word, (neighbour, 1)) -> (word, (neighbour, sum(neighbour)))
        // flatMapValues operates only on the values; the keys are left untouched
        JavaPairRDD<String, Tuple2<String, Integer>> uniquePairs = grouped.flatMapValues(
                // Function<T1, R> -> R call(T1 v1)
                new Function<Iterable<Tuple2<String, Integer>>, Iterable<Tuple2<String, Integer>>>() {
            private static final long serialVersionUID = 5790208031487657081L;

            @Override
            public Iterable<Tuple2<String, Integer>> call(Iterable<Tuple2<String, Integer>> values) throws Exception {
                Map<String, Integer> map = new HashMap<>();
                List<Tuple2<String, Integer>> list = new ArrayList<>();
                Iterator<Tuple2<String, Integer>> iterator = values.iterator();
                while (iterator.hasNext()) {
                    Tuple2<String, Integer> value = iterator.next();
                    int total = value._2;
                    if (map.containsKey(value._1)) {
                        total += map.get(value._1);
                    }
                    map.put(value._1, total);
                }
                for (Map.Entry<String, Integer> kv : map.entrySet()) {
                    list.add(new Tuple2<String, Integer>(kv.getKey(), kv.getValue()));
                }
                return list;
            }
        });

        // (word, ((neighbour, sum(neighbour)), sum(word)))
        JavaPairRDD<String, Tuple2<Tuple2<String, Integer>, Integer>> joined = uniquePairs.join(totalByKey);

        // ((key, neighbour), sum(neighbour)/sum(word))
        JavaPairRDD<Tuple2<String, String>, Double> relativeFrequency = joined.mapToPair(
                new PairFunction<Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>>, Tuple2<String, String>, Double>() {
            private static final long serialVersionUID = 3870784537024717320L;

            @Override
            public Tuple2<Tuple2<String, String>, Double> call(Tuple2<String, Tuple2<Tuple2<String, Integer>, Integer>> tuple) throws Exception {
                return new Tuple2<Tuple2<String, String>, Double>(new Tuple2<String, String>(tuple._1, tuple._2._1._1), ((double) tuple._2._1._2 / tuple._2._2));
            }
        });

        // For saving the output in tab separated format
        // ((key, neighbour), relative_frequency)
        // format each record as a single tab-separated String
        JavaRDD<String> formatResult_tab_separated = relativeFrequency.map(
                new Function<Tuple2<Tuple2<String, String>, Double>, String>() {
            private static final long serialVersionUID = 7312542139027147922L;

            @Override
            public String call(Tuple2<Tuple2<String, String>, Double> tuple) throws Exception {
                return tuple._1._1 + "\t" + tuple._1._2 + "\t" + tuple._2;
            }
        });

        // save output
        formatResult_tab_separated.saveAsTextFile(output);

        // done
        sc.close();

    }

Spark SQL


public static void main(String[] args) {
        if (args.length < 3) {
            System.out.println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>");
            System.exit(1);
        }

        SparkConf sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency");
        // create the SparkSession required by Spark SQL
        SparkSession spark = SparkSession
                .builder()
                .appName("SparkSQLRelativeFrequency")
                .config(sparkConf)
                .getOrCreate();

        JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
        int neighborWindow = Integer.parseInt(args[0]);
        String input = args[1];
        String output = args[2];

        final Broadcast<Integer> brodcastWindow = sc.broadcast(neighborWindow);

        /*
         * Register a schema; the frequency column is used by the query later
         * Schema (word, neighbour, frequency)
         */
        StructType rfSchema = new StructType(new StructField[]{
            new StructField("word", DataTypes.StringType, false, Metadata.empty()),
            new StructField("neighbour", DataTypes.StringType, false, Metadata.empty()),
            new StructField("frequency", DataTypes.IntegerType, false, Metadata.empty())});

        JavaRDD<String> rawData = sc.textFile(input);

        /*
         * Transform the input to the format: Row(word, neighbour, 1)
         */
        JavaRDD<Row> rowRDD = rawData
                .flatMap(new FlatMapFunction<String, Row>() {
                    private static final long serialVersionUID = 5481855142090322683L;

                    @Override
                    public Iterator<Row> call(String line) throws Exception {
                        List<Row> list = new ArrayList<>();
                        String[] tokens = line.split("\\s");
                        for (int i = 0; i < tokens.length; i++) {
                            int start = (i - brodcastWindow.value() < 0) ? 0
                                    : i - brodcastWindow.value();
                            int end = (i + brodcastWindow.value() >= tokens.length) ? tokens.length - 1
                                    : i + brodcastWindow.value();
                            for (int j = start; j <= end; j++) {
                                if (j != i) {
                                    list.add(RowFactory.create(tokens[i], tokens[j], 1));
                                } else {
                                    // do nothing
                                    continue;
                                }
                            }
                        }
                        return list.iterator();
                    }
                });
        // create the DataFrame
        Dataset<Row> rfDataset = spark.createDataFrame(rowRDD, rfSchema);
        // register rfDataset as a temporary view so it can be queried with SQL
        rfDataset.createOrReplaceTempView("rfTable");

        String query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf "
                + "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a "
                + "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word";
        Dataset<Row> sqlResult = spark.sql(query);

        sqlResult.show(); // print first 20 records on the console
        sqlResult.write().parquet(output + "/parquetFormat"); // saves output in compressed Parquet format, recommended for large projects
        sqlResult.rdd().saveAsTextFile(output + "/textFormat"); // to see output via cat command

        // done
        sc.close();
        spark.stop();

    }


Scala

def main(args: Array[String]): Unit = {

    if (args.size < 3) {
      println("Usage: RelativeFrequency <neighbor-window> <input-dir> <output-dir>")
      sys.exit(1)
    }

    val sparkConf = new SparkConf().setAppName("RelativeFrequency")
    val sc = new SparkContext(sparkConf)

    val neighborWindow = args(0).toInt
    val input = args(1)
    val output = args(2)

    val brodcastWindow = sc.broadcast(neighborWindow)

    val rawData = sc.textFile(input)

    /*
     * Transform the input to the format:
     * (word, (neighbour, 1))
     */
    val pairs = rawData.flatMap(line => {
      val tokens = line.split("\\s")
      for {
        i <- 0 until tokens.length
        start = if (i - brodcastWindow.value < 0) 0 else i - brodcastWindow.value
        end = if (i + brodcastWindow.value >= tokens.length) tokens.length - 1 else i + brodcastWindow.value
        j <- start to end if (j != i)
        // yield collects the generated tuples (word, (neighbour, 1))
      } yield (tokens(i), (tokens(j), 1))
    })

    // (word, sum(word))
    val totalByKey = pairs.map(t => (t._1, t._2._2)).reduceByKey(_ + _)

    val grouped = pairs.groupByKey()

    // (word, (neighbour, sum(neighbour)))
    val uniquePairs = grouped.flatMapValues(_.groupBy(_._1).mapValues(_.unzip._2.sum))

    // join the two RDDs
    // (word, ((neighbour, sum(neighbour)), sum(word)))
    val joined = uniquePairs join totalByKey

    // ((key, neighbour), sum(neighbour)/sum(word))
    val relativeFrequency = joined.map(t => {
      ((t._1, t._2._1._1), (t._2._1._2.toDouble / t._2._2.toDouble))
    })

    // For saving the output in tab separated format
    // ((key, neighbour), relative_frequency)
    val formatResult_tab_separated = relativeFrequency.map(t => t._1._1 + "\t" + t._1._2 + "\t" + t._2)
    formatResult_tab_separated.saveAsTextFile(output)

    // done
    sc.stop()
  }
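
A note on this design: pairs.groupByKey() materializes every (neighbour, 1) tuple for a word before the per-neighbour sums are computed. Below is a sketch of an alternative, not in the original post, that aggregates with reduceByKey on a composite (word, neighbour) key and then re-keys by word; it reuses pairs and totalByKey from the code above:

    // ((word, neighbour), count) via reduceByKey, then re-key by word
    val neighbourSums = pairs
      .map { case (word, (neighbour, one)) => ((word, neighbour), one) }
      .reduceByKey(_ + _)
      .map { case ((word, neighbour), cnt) => (word, (neighbour, cnt)) }

    // joined with totalByKey exactly as before
    val relativeFrequency2 = neighbourSums.join(totalByKey).map {
      case (word, ((neighbour, cnt), total)) => ((word, neighbour), cnt.toDouble / total)
    }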

Spark SQL in Scala

def main(args: Array[String]): Unit = {

    if (args.size < 3) {
      println("Usage: SparkSQLRelativeFrequency <neighbor-window> <input-dir> <output-dir>")
      sys.exit(1)
    }

    val sparkConf = new SparkConf().setAppName("SparkSQLRelativeFrequency")

    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .getOrCreate()
    val sc = spark.sparkContext

    val neighborWindow = args(0).toInt
    val input = args(1)
    val output = args(2)

    val brodcastWindow = sc.broadcast(neighborWindow)

    val rawData = sc.textFile(input)

    /*
     * Schema
     * (word, neighbour, frequency)
     */
    val rfSchema = StructType(Seq(
      StructField("word", StringType, false),
      StructField("neighbour", StringType, false),
      StructField("frequency", IntegerType, false)))

    /*
     * Transform the input to the format:
     * Row(word, neighbour, 1)
     */
    // build rows matching the StructType defined above
    val rowRDD = rawData.flatMap(line => {
      val tokens = line.split("\\s")
      for {
        i <- 0 until tokens.length
        // the window calculation rule; note it differs from the MapReduce version
        start = if (i - brodcastWindow.value < 0) 0 else i - brodcastWindow.value
        end = if (i + brodcastWindow.value >= tokens.length) tokens.length - 1 else i + brodcastWindow.value
        j <- start to end if (j != i)
      } yield Row(tokens(i), tokens(j), 1)
    })

    val rfDataFrame = spark.createDataFrame(rowRDD, rfSchema)
    // register the rfTable temporary view
    rfDataFrame.createOrReplaceTempView("rfTable")

    import spark.sql

    val query = "SELECT a.word, a.neighbour, (a.feq_total/b.total) rf " +
      "FROM (SELECT word, neighbour, SUM(frequency) feq_total FROM rfTable GROUP BY word, neighbour) a " +
      "INNER JOIN (SELECT word, SUM(frequency) as total FROM rfTable GROUP BY word) b ON a.word = b.word"

    val sqlResult = sql(query)
    sqlResult.show() // print first 20 records on the console
    sqlResult.write.save(output + "/parquetFormat") // saves output in compressed Parquet format, recommended for large projects
    sqlResult.rdd.saveAsTextFile(output + "/textFormat") // to see output via cat command

    // done
    spark.stop()

  }
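
For completeness, the same relative-frequency query can also be expressed with the DataFrame API instead of a SQL string. This is a sketch rather than part of the original code; it assumes the rfDataFrame with the (word, neighbour, frequency) schema created above:

    import org.apache.spark.sql.functions.sum

    // sum(frequency) per (word, neighbour) pair
    val pairTotals = rfDataFrame
      .groupBy("word", "neighbour")
      .agg(sum("frequency").as("feq_total"))

    // sum(frequency) per word
    val wordTotals = rfDataFrame
      .groupBy("word")
      .agg(sum("frequency").as("total"))

    // relative frequency = feq_total / total, same result as the SQL query above
    val dfResult = pairTotals
      .join(wordTotals, "word")
      .selectExpr("word", "neighbour", "feq_total / total AS rf")

    dfResult.show()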
