How to find the index of the maximum value in a vector column?

…衆ロ難τιáo~ submitted on 2021-01-27 07:41:15

Question


I have a Spark DataFrame with the following structure:

root
|-- distribution: vector (nullable = true)

+--------------------+
|   topicDistribution|
+--------------------+
|     [0.1, 0.2]     |
|     [0.3, 0.2]     |
|     [0.5, 0.2]     |
|     [0.1, 0.7]     |
|     [0.1, 0.8]     |
|     [0.1, 0.9]     |
+--------------------+

My question is: how do I add a column with the index of the maximum value for each row?

It should be something like this:

root
|-- distribution: vector (nullable = true)
|-- max_index: integer (nullable = true)

+--------------------+-----------+
|   topicDistribution| max_index |
+--------------------+-----------+
|     [0.1, 0.2]     |   1       | 
|     [0.3, 0.2]     |   0       | 
|     [0.5, 0.2]     |   0       | 
|     [0.1, 0.7]     |   1       | 
|     [0.1, 0.8]     |   1       | 
|     [0.1, 0.9]     |   1       | 
+--------------------+-----------+

Thanks a lot

I tried the following method but I got an error:

import org.apache.spark.sql.functions.udf

val func = udf( (x: Vector[Double]) => x.indices.maxBy(x) )

df.withColumn("max_idx",func(($"topicDistribution"))).show()

The error says:

Exception in thread "main" org.apache.spark.sql.AnalysisException: 
cannot resolve 'UDF(topicDistribution)' due to data type mismatch: 
argument 1 requires array<double> type, however, '`topicDistribution`' 
is of vector type.;;
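The lambda itself is not the problem; its parameter type is. Without an import, Vector[Double] in Scala refers to scala.collection.immutable.Vector, an ordinary collection, so Spark expects an array<double> column (as the error says) rather than the ML vector type. As a minimal check outside Spark, the sketch below runs the same indices.maxBy expression on a plain Scala Vector. The object and method names are hypothetical, chosen only for illustration.

```scala
object AskerLambdaCheck {
  // Same body as the asker's UDF, but here Vector is scala.collection.immutable.Vector.
  // A Scala Vector is also a function Int => Double, so maxBy(x) looks up each index's value.
  // maxBy keeps the first index on ties and throws on an empty vector.
  def maxIndex(x: Vector[Double]): Int = x.indices.maxBy(x)
}
```

Inside Spark, the parameter has to be Spark's own vector class instead, which is the approach Answer 1 takes.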

Answer 1:


// create some sample data:
import org.apache.spark.mllib.linalg.{Vectors,Vector}
case class myrow(topics:Vector)

val rdd = sc.parallelize(Array(myrow(Vectors.dense(0.1,0.2)),myrow(Vectors.dense(0.6,0.2))))
val mydf = sqlContext.createDataFrame(rdd)
mydf.show()
+----------+
|    topics|
+----------+
|[0.1, 0.2]|
|[0.6, 0.2]|
+----------+

// build the udf
import org.apache.spark.sql.functions.udf
val func = udf { (x: Vector) =>
  val values = x.toArray        // the vector's values as a dense Array[Double]
  values.indices.maxBy(values)  // index of the first maximum value
}


mydf.withColumn("max_idx",func($"topics")).show()
+----------+-------+
|    topics|max_idx|
+----------+-------+
|[0.1, 0.2]|      1|
|[0.6, 0.2]|      0|
+----------+-------+

// Edited: the UDF now takes a Vector instead of a Seq, as your original question and your comment asked.
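Both of Spark's vector classes, org.apache.spark.mllib.linalg.Vector and org.apache.spark.ml.linalg.Vector, already define argmax, documented as returning the index of the first maximal element, or -1 for an empty vector. So the UDF body could also be written as udf((v: Vector) => v.argmax). Whichever Vector class the UDF takes should match the package that produced the column: spark.ml transformers such as VectorAssembler (used in Answer 2) produce ml.linalg vectors, so for such a column import org.apache.spark.ml.linalg.Vector instead of the mllib class used above. The sketch below is a hypothetical plain-Scala re-implementation of those documented argmax semantics over the raw values, alongside the expression this answer uses, so the edge cases can be checked without a Spark session.

```scala
object ArgmaxSketch {
  // Hypothetical helper mirroring the documented behavior of Spark's Vector.argmax:
  // index of the first maximal element, or -1 for an empty array.
  def argmax(values: Array[Double]): Int = {
    if (values.isEmpty) -1
    else {
      var maxIdx = 0
      var i = 1
      while (i < values.length) {
        // Strict '>' keeps the earliest index when values tie.
        if (values(i) > values(maxIdx)) maxIdx = i
        i += 1
      }
      maxIdx
    }
  }

  // Equivalent to this answer's indices.maxBy expression; throws on an empty array.
  def answerOneStyle(values: Array[Double]): Int =
    values.indices.maxBy(i => values(i))
}
```

The practical difference is the empty case: argmax returns -1, whereas maxBy throws, and an exception thrown inside a UDF fails the Spark task.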




Answer 2:


NOTE: This solution may not be the best performance-wise; it just shows another approach to the problem (and how rich Spark SQL's Dataset API is).


The vector type comes from Spark MLlib's VectorUDT, so let me create a sample dataset first.

val input = Seq((0.1, 0.2), (0.3, 0.2)).toDF
import org.apache.spark.ml.feature.VectorAssembler
val vecAssembler = new VectorAssembler()
  .setInputCols(Array("_1", "_2"))
  .setOutputCol("distribution")
val ds = vecAssembler.transform(input).select("distribution")
scala> ds.printSchema
root
 |-- distribution: vector (nullable = true)

The schema looks exactly like yours.


Let's change the type from VectorUDT to the regular Array[Double].

import org.apache.spark.ml.linalg.Vector
val arrays = ds
  .map { r => r.getAs[Vector](0).toArray }
  .withColumnRenamed("value", "distribution")
scala> arrays.printSchema
root
 |-- distribution: array (nullable = true)
 |    |-- element: double (containsNull = false)

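With arrays, you can use posexplode to pair each element with its position (columns pos and col), groupBy the array to compute the maximum element value, and then join the two and filter to keep the position that holds that maximum.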

val pos = arrays.select($"*", posexplode($"distribution"))
val max_cols = pos
  .groupBy("distribution")
  .agg(max("col") as "max_col")
val solution = pos
  .join(max_cols, "distribution")
  .filter($"col" === $"max_col")
  .select("distribution", "pos")
scala> solution.show
+------------+---+
|distribution|pos|
+------------+---+
|  [0.1, 0.2]|  1|
|  [0.3, 0.2]|  0|
+------------+---+
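Two consequences of this join-based approach are easy to miss: a row whose maximum value appears at several positions produces one output row per tied position (unlike argmax, which picks the first), and an empty array produces no output row at all, because posexplode emits nothing for it. Spark also does not guarantee the order of the result rows. The sketch below is a hypothetical plain-Scala emulation of the three steps on ordinary collections (names are illustrative, not Spark API), with zipWithIndex standing in for posexplode, so the tie and empty-array behavior can be checked directly.

```scala
object PosexplodeJoinEmulation {
  // Hypothetical emulation of the pos -> max_cols -> join -> filter pipeline.
  // Returns one (distribution, pos) pair per position holding its row's maximum.
  def maxPositions(rows: Seq[Seq[Double]]): Seq[(Seq[Double], Int)] = {
    // posexplode: one (distribution, pos, col) triple per element; empty rows yield nothing.
    val pos: Seq[(Seq[Double], Int, Double)] =
      rows.flatMap(d => d.zipWithIndex.map { case (col, i) => (d, i, col) })

    // groupBy("distribution").agg(max("col")): maximum element per distinct distribution.
    val maxCols: Map[Seq[Double], Double] =
      pos.groupBy(_._1).map { case (d, triples) => d -> triples.map(_._3).reduce(_ max _) }

    // join on distribution, filter(col === max_col), then select(distribution, pos).
    pos.collect { case (d, i, col) if col == maxCols(d) => (d, i) }
  }
}
```

If only one index per row is wanted, the tied rows would need to be reduced further (for example by taking the minimum pos per row), which the UDF-based answer avoids by construction.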


Source: https://stackoverflow.com/questions/47560366/how-to-find-the-index-of-the-maximum-value-in-a-vector-column
