Prepare data for MultilayerPerceptronClassifier in Scala

Submitted by 安稳与你 on 2019-12-01 09:09:16

The source of your problem is an incorrect definition of layers. When you use

val layers = Array[Int](0, 0, 0, 0)

it means you want a network with zero neurons in each layer, which simply doesn't make sense. Generally speaking, the number of neurons in the input layer should equal the number of features, and each hidden layer should contain at least one neuron.
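
The two rules above can be checked before training. The following is a minimal sketch in plain Scala; validateLayers is a hypothetical helper written for illustration, not part of Spark, and it covers only the two conditions mentioned here.

```scala
// Hypothetical helper (not a Spark API): checks a layer specification
// against the two rules described above before it is passed to setLayers.
def validateLayers(layers: Array[Int], numFeatures: Int): Either[String, Array[Int]] = {
  val shown = layers.mkString(", ")
  if (layers.isEmpty)
    Left("no layers defined")
  else if (layers.exists(_ <= 0))
    // Rule 2: every layer needs at least one neuron.
    Left("every layer needs at least one neuron, got: " + shown)
  else if (layers.head != numFeatures)
    // Rule 1: the input layer must match the number of features.
    Left("input layer has " + layers.head + " neurons but there are " + numFeatures + " features")
  else
    Right(layers)
}

// The definition from the question is rejected; a sensible one passes.
val rejected = validateLayers(Array(0, 0, 0, 0), 2)
val accepted = validateLayers(Array(2, 3, 5, 3), 2)
```

Here rejected is a Left carrying the reason, while accepted is a Right wrapping the array, which could then be handed to setLayers.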

Let's recreate your data, simplifying your code along the way:

import org.apache.spark.sql.functions.col

val df = sc.parallelize(Seq(
  ("0", "0", "0"), ("0", "0", "1628859.542")
)).toDF("gst_id_matched", "ip_crowding", "lat_long_dist")

Convert all columns to doubles:

val numeric = df
  .select(df.columns.map(c => col(c).cast("double").alias(c)): _*)
  .withColumnRenamed("gst_id_matched", "label")

Assemble features:

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
  .setInputCols(Array("ip_crowding","lat_long_dist"))
  .setOutputCol("features")

val data = assembler.transform(numeric)
data.show

// +-----+-----------+-------------+-----------------+
// |label|ip_crowding|lat_long_dist|         features|
// +-----+-----------+-------------+-----------------+
// |  0.0|        0.0|          0.0|        (2,[],[])|
// |  0.0|        0.0|  1628859.542|[0.0,1628859.542]|
// +-----+-----------+-------------+-----------------+

Train and test network:

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

val layers = Array[Int](2, 3, 5, 3) // Note 2 neurons in the input layer
val trainer = new MultilayerPerceptronClassifier()
  .setLayers(layers)
  .setBlockSize(128)
  .setSeed(1234L)
  .setMaxIter(100)

val model = trainer.fit(data)
model.transform(data).show

// +-----+-----------+-------------+-----------------+----------+
// |label|ip_crowding|lat_long_dist|         features|prediction|
// +-----+-----------+-------------+-----------------+----------+
// |  0.0|        0.0|          0.0|        (2,[],[])|       0.0|
// |  0.0|        0.0|  1628859.542|[0.0,1628859.542]|       0.0|
// +-----+-----------+-------------+-----------------+----------+
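
To keep the input layer in sync with the assembled columns, one option is to derive its size from the same array that is passed to setInputCols. This is only a sketch: layersFor, inputCols and derivedLayers are hypothetical names introduced for illustration, and the hidden and output sizes are simply the ones used above.

```scala
// Hypothetical helper (not a Spark API): builds the layer array so that
// the input layer always has one neuron per feature column.
def layersFor(inputCols: Array[String], hidden: Seq[Int], numOutputs: Int): Array[Int] =
  (Seq(inputCols.length) ++ hidden ++ Seq(numOutputs)).toArray

// Same feature columns as in the VectorAssembler above.
val inputCols = Array("ip_crowding", "lat_long_dist")

// Hidden layers of 3 and 5 neurons and 3 outputs, as in the example.
val derivedLayers = layersFor(inputCols, Seq(3, 5), 3)
```

With this approach, inputCols would go to setInputCols and derivedLayers to setLayers, so adding or removing a feature column changes the input layer automatically.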