简介
1、K-均值距离函数
1.1、欧式距离
欧式距离的计算公式 $$ d(x,y) = \sqrt{(x_1 - y_1)^2 + (x_2 - y_2)^2 + ... + (x_n - y_n)^2} $$
其中,x,y分别代表两个点,同时,两个点具有相同的维度:n。$x_1,x_2,...,x_n$代表点x的每个维度的值,$y_1,y_2,...,y_n$代表点y的各个维度的值。
1.2、欧氏距离的性质
假设有$p_1,p_2,p_{k}$3个点。
-
$d(p_1,p_2) \ge 0$
-
$d(p_i,p_i) = 0$
-
$d(p_i,p_j) = d(p_j,p_i)$
-
$d(p_i,p_j) \le d(p_i,p_k) + d(p_k,p_j)$
最后一个性质也说明了一个很常见的现象:两点间的距离,线段最短。
1.3、源码实现
import java.util.List;
/**
* 欧式距离计算
*/
public class EuclideanDistance {
public static double caculate(List<Double> p1, List<Double> p2){
double sum = 0.0;
int length = p1.size();
for (int i = 0; i < length; i++) {
sum += Math.pow(p1.get(i) - p2.get(i),2.0);
}
return Math.sqrt(sum);
}
}
2、形式化描述
K-均值算法是一个完成聚类分析的简单学习算法。K-均值聚类算法的目标是找出n项的最佳划分,也就是将n个对象划分到K个组中,是的一个组中的成员语气相应的质心(表示这个组)之间的总距离最小。采用形式化表示,目标就是将n项划分到K个集合$$ {S_i,i=1,2,...,K} $$ 中,使得簇内平方和或组内平方和(within-cluster sum of squares,WCSS)最小,WCSS定义为 $$ \min \sum_{j=1}^k \sum_{i=1}^n ||x_{i}^j - c_j|| $$
这里的$||x_i^j - c_j||$表示实体点与质心之间的距离。
3、MapReduce实现
3.1、数据集
如下所示,我们选用的二位数据集。
1.0,2.0
1.0,3.0
1.0,4.0
2.0,5.0
2.0,3.0
2.0,7.0
2.0,8.0
3.0,100.0
3.0,101.0
3.0,102.0
3.0,103.0
3.0,104.0
3.2、Mapper
package mapreduce;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
public class KMeansMapper extends Mapper<LongWritable, Text, IntWritable, Text> {
private List<List<Double>> centers = null;
// K
private int k = 0;
/**
* map 开始时调用一次。
* @param context
* @throws IOException
* @throws InterruptedException
*/
@Override
protected void setup(Context context) throws IOException, InterruptedException {
// config
String centerPath = context.getConfiguration().get("centerPath");
// 读取质心点信息
this.centers = KMeansUtil.getCenterFromFileSystem(centerPath);
// 获取K值(中心点个数)
k = centers.size();
System.out.println("当前的质心数据为:" + centers);
}
/**
* 1.每次读取一条要分类的条记录与中心做对比,归类到对应的中心
* 2.以中心ID为key,中心包含的记录为value输出(例如: 1 0.2---->1为聚类中心的ID,0.2为靠近聚类中心的某个值)
*/
@Override
protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
// 读取一行数据
List<Double> fields = KMeansUtil.textToList(value);
// 点维度
int dimension = fields.size();
double minDistance = Double.MAX_VALUE;
int centerIndex = 0;
// 依次取出K个中心点与当前读取的记录做计算
for (int i = 0; i < k; i++) {
double currentDistance = 0.0;
// 之所以跳过0,是因为1代表的是该点的ID,不纳入计算的范畴
for (int j = 1; j < dimension; j++) {
// 获取中心点
double centerPoint = Math.abs(centers.get(i).get(j));
// 当前需要计算的点
double field = Math.abs(fields.get(j));
// 计算欧氏距离
currentDistance += Math.pow((centerPoint - field) / (centerPoint + field), 2);
}
// 找出距离该记录最近的中心点的ID,记录最小值、该点的索引
if(currentDistance < minDistance){
minDistance = currentDistance;
centerIndex = i;
}
}
// 以中心点为key,原样输出,这样以该中心点为key的点都会作为一个簇在reducer端汇聚
context.write(new IntWritable(centerIndex),value);
}
}
3.3、Reuder
package mapreduce;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* 利用reduce归并功能以中心为key将记录归并在一起
*/
public class KMeansReducer extends Reducer<IntWritable, Text, NullWritable, Text>{
/**
* 1.K-V: Key为聚类中心的ID;value为该中心的记录集合;
* 2.计数所有记录元素的平均值,求出新的中心;KMeans算法的最终结果选取的质心点一般不是原数据集中的点
*/
@Override
protected void reduce(IntWritable key, Iterable<Text> values, Context context) throws IOException, InterruptedException {
List<List<Double>> result = new ArrayList<List<Double>>();
// 依次读取记录集,每行转化为一个List<Double>
for (Text value : values) {
result.add(KMeansUtil.textToList(value));
}
// 计算新的质心点:通过各个维的平均值
int dimension = result.get(0).size();
double[] averages = new double[dimension];
for (int i = 0; i < dimension; i++) {
double sum = 0.0;
int size = result.size();
for (int j = 0; j < size; j++) {
sum += result.get(j).get(i);
}
averages[i] = sum / size;
}
context.write(NullWritable.get(),new Text(Arrays.toString(averages).replace("[","").replace("]","")));
}
}
3.4、Driver
package mapreduce;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import java.io.IOException;
import java.util.List;
public class KMeansDriver {
public static void main(String[] args) throws Exception{
String dfs = "hdfs://192.168.35.128:9000";
// 存放中心点坐标值
String centerPath = dfs + "/kmeans/center/";
// 存放待处理数据
String dataPath = dfs + "/kmeans/kmeans_input_file.txt";
// 新中心点存放目录
String newCenterPath = dfs + "/kmeans/newCenter/";
// delta
double delta = 0.1D;
int count = 0;
final int K = 3;
// 选取初始的K个质心点
List<List<Double>> pick = KMeansUtil.pick(K, dfs + "/kmeans/kmeans_input_file.txt");
// 存储到结果集
KMeansUtil.writeCurrentKClusterToCenter(centerPath + "center.data",pick);
while(true){
++ count;
System.out.println(" 第 " + count + " 次计算 ");
run(dataPath, centerPath, newCenterPath);
System.out.println("计算迭代变化值");
// 比较新旧质点变化幅度
if(KMeansUtil.compareCenters(centerPath, newCenterPath,delta)){
System.out.println("迭代结束");
break;
}
}
/**
* 第 1 次计算
* 当前的质心数据为:[[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
* task running status is : 1
* 计算迭代变化值
* 当前的质心点迭代变化值: 2125.9917355371904
* 第 2 次计算
* 当前的质心数据为:[[1.0, 1.0], [1.0, 2.0], [2.272727272727273, 49.09090909090909]]
* task running status is : 1
* 计算迭代变化值
* 当前的质心点迭代变化值: 2806.839601956485
* 第 3 次计算
* 当前的质心数据为:[[1.0, 1.0], [1.5714285714285714, 4.571428571428571], [3.0, 102.0]]
* task running status is : 1
* 计算迭代变化值
* 当前的质心点迭代变化值: 0.44274376417233585
* 第 4 次计算
* 当前的质心数据为:[[1.0, 1.5], [1.6666666666666667, 5.0], [3.0, 102.0]]
* task running status is : 1
* 计算迭代变化值
* 当前的质心点迭代变化值: 0.0
* 迭代结束
*/
}
public static void run(String dataPath, String centerPath, String newCenterPath) throws IOException, ClassNotFoundException, InterruptedException {
Configuration configuration = new Configuration();
configuration.set("centerPath", centerPath);
Job job = Job.getInstance(configuration);
job.setJarByClass(KMeansDriver.class);
job.setMapperClass(KMeansMapper.class);
job.setMapOutputKeyClass(IntWritable.class);
job.setMapOutputValueClass(Text.class);
job.setReducerClass(KMeansReducer.class);
job.setOutputKeyClass(NullWritable.class);
job.setOutputValueClass(Text.class);
FileInputFormat.setInputPaths(job,new Path(dataPath));
FileOutputFormat.setOutputPath(job,new Path(newCenterPath) );
System.out.println("task running status is : " + (job.waitForCompletion(true)? 1:0));
}
}
我们还可以写一个Combiner优化网络传输的流量,不过此处由于测试的缘故,就不写不是本章节主题的代码了。
另外,这几个类还使用了一个辅助工具类
package mapreduce;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
/**
* KMeans工具
*/
public class KMeansUtil {
public static FileSystem getFileSystem() throws URISyntaxException, IOException, InterruptedException {
// 获取一个具体的文件系统对象
return FileSystem.get(new URI("hdfs://192.168.35.128:9000"),new Configuration(),"root");
}
/**
* 在数据集中选取前K个点作为质心
* @param k
* @param filePath
* @return
*/
public static List<List<Double>> pick(int k, String filePath) throws Exception {
List<List<Double>> result = new ArrayList<List<Double>>();
Path path = new Path(filePath);
FileSystem fileSystem = getFileSystem();
FSDataInputStream open = fileSystem.open(path);
LineReader lineReader = new LineReader(open);
Text line = new Text();
// 读取每一行信息
while(lineReader.readLine(line) > 0 && k > 0){
List<Double> doubles = textToList(line);
result.add(doubles);
k = k - 1;
}
lineReader.close();
return result;
}
/**
* 将当前的结果写入数据中心
*/
public static void writeCurrentKClusterToCenter(String centerPath,List<List<Double>> data) throws Exception {
FSDataOutputStream out = getFileSystem().create(new Path(centerPath));
for (List<Double> d : data) {
String str = d.toString();
out.write(str.replace("[","").replace("]","\n").getBytes());
}
out.close();
}
/**
* 从数据中心获取质心点数据
* @param filePath 路径
* @return 质心数据
*/
public static List<List<Double>> getCenterFromFileSystem(String filePath) throws IOException {
List<List<Double>> result = new ArrayList<List<Double>>();
Path path = new Path(filePath);
Configuration configuration = new Configuration();
FileSystem fileSystem = null;
try {
fileSystem = getFileSystem();
} catch (Exception e) {
e.printStackTrace();
}
FileStatus[] listFiles = fileSystem.listStatus(path);
for (FileStatus file : listFiles) {
FSDataInputStream open = fileSystem.open(file.getPath());
LineReader lineReader = new LineReader(open, configuration);
Text line = new Text();
// 读取每一行信息
while(lineReader.readLine(line) > 0){
List<Double> doubles = textToList(line);
result.add(doubles);
}
}
return result;
}
/**
* 将Text转化为数组
* @param text
* @return
*/
public static List<Double> textToList(Text text){
List<Double> list = new ArrayList<Double>();
String[] split = text.toString().split(",");
for (int i = 0; i < split.length; i++) {
list.add(Double.parseDouble(split[i]));
}
return list;
}
/**
* 比较新旧数据点的变化情况
* @return
* @throws Exception
*/
public static boolean compareCenters(String center, String newCenter, double delta) throws Exception{
List<List<Double>> oldCenters = getCenterFromFileSystem(center);
List<List<Double>> newCenters = getCenterFromFileSystem(newCenter);
// 质心点数
int size = oldCenters.size();
// 维度
int fieldSize = oldCenters.get(0).size();
double distance = 0.0;
for (int i = 0; i < size; i++) {
for (int j = 0; j < fieldSize; j++) {
double p1 = Math.abs(oldCenters.get(i).get(j));
double p2 = Math.abs(newCenters.get(i).get(j));
// this is used euclidean distance.
distance += Math.pow(p1 - p2, 2);
}
}
System.out.println("当前的质心点迭代变化值: " + distance);
// 在区间内
if(distance <= delta){
return true;
}else{
Path centerPath = new Path(center);
Path newCenterPath = new Path(newCenter);
FileSystem fs = getFileSystem();
// 删除当前质点文件
fs.delete(centerPath,true );
// 将新质点文件结果移动到当前质点文件
fs.rename(newCenterPath,centerPath);
}
return false;
}
}
可以看到,我们的K=3,并且选择的是数据集中的前三个点作为初始迭代的质心点。当然,更好的算法应该是从数据集中随机选取3个点或者以贴合业务的选取方式选取初始点,从算法中我们可以了解到,初始点的选择在一定迭代次数内是对结果有很大的影响的。
3.5、绘图
最终,我们得到的结果如下,其中的红点即为质心点
来源:oschina
链接:https://my.oschina.net/u/3091870/blog/3023599