EM算法-混合高斯模型

安稳与你 提交于 2020-03-11 17:16:24

最近在学EM算法,看到大佬写的博客很好,我仅转载:https://blog.csdn.net/coldnoble/article/details/41625911?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522158390483919725211958727%2522%252C%2522scm%2522%253A%252220140713.130056874…%2522%257D&request_id=158390483919725211958727&biz_id=0&utm_source=distribute.pc_search_result.none-task

最近在看李航的《统计学习方法》一书,关于EM算法部分收集了些资料进行了学习,做了些混合高斯的模拟,下面分三个部分介绍下相关内容:1)EM算法原理,2)混合高斯推导,3)相关代码和结果

一、EM算法原理

EM算法推导中一个重要的概念是Jensen不等式,其表述为:如果 $f$ 为凸函数($f''(x)\ge 0$),则有 $E[f(X)]\ge f(E[X])$;当 $f$ 严格凸时,当且仅当 $X=E[X]$(即 $X$ 以概率1为常数)的时候不等式两边等号才成立。

如果概率模型只针对观测样本 $Y$,那么根据 $Y$ 的观测值,可以通过极大似然或贝叶斯估计法估计其参数 $\theta$。但是,如果概率模型不仅包含观测样本 $Y$,还含有隐变量 $Z$(无法观测其值),这时就需要EM算法来估计隐变量 $Z$ 的分布和观测样本 $Y$ 的模型参数 $\theta$,也可以认为EM算法是含有隐变量的极大似然估计法。

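观测数据 $Y=\{y_1,\dots,y_N\}$,隐变量 $Z=\{z_1,\dots,z_N\}$,用 $Q_j(z_j)$ 表示样本 $y_j$ 的隐变量分布,满足 $\sum_{z_j}Q_j(z_j)=1$,$Q_j(z_j)\ge 0$。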


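则似然函数可以表示为(其中 $Q_j(z_j)$ 为样本 $y_j$ 的隐变量分布)

$$L(\theta)=\sum_{j=1}^{N}\log p(y_j;\theta)=\sum_{j=1}^{N}\log\sum_{z_j}p(y_j,z_j;\theta)=\sum_{j=1}^{N}\log\sum_{z_j}Q_j(z_j)\frac{p(y_j,z_j;\theta)}{Q_j(z_j)}$$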

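$\sum_{z_j}Q_j(z_j)\frac{p(y_j,z_j;\theta)}{Q_j(z_j)}$ 可以看成是 $\frac{p(y_j,z_j;\theta)}{Q_j(z_j)}$ 在分布 $Q_j$ 下的期望,由Jensen不等式可得

$$\log E_{z_j\sim Q_j}\left[\frac{p(y_j,z_j;\theta)}{Q_j(z_j)}\right]\ge E_{z_j\sim Q_j}\left[\log\frac{p(y_j,z_j;\theta)}{Q_j(z_j)}\right]$$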

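这里log函数是凹函数,不等号方向与凸函数的情形相反,故有

$$L(\theta)\ge\sum_{j=1}^{N}\sum_{z_j}Q_j(z_j)\log\frac{p(y_j,z_j;\theta)}{Q_j(z_j)}$$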

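如果需要确定下界(使下界与 $L(\theta)$ 相等),则需要等号成立,则有 $\frac{p(y_j,z_j;\theta)}{Q_j(z_j)}=c$ 为常数,结合 $\sum_{z_j}Q_j(z_j)=1$ 则有

$$Q_j(z_j)=\frac{p(y_j,z_j;\theta)}{\sum_{z}p(y_j,z;\theta)}=\frac{p(y_j,z_j;\theta)}{p(y_j;\theta)}=p(z_j\mid y_j;\theta)$$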



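已知了 $Q_j(z_j)$ 后就可以调整 $\theta$ 来优化下界

$$\theta^{(t+1)}=\arg\max_{\theta}\sum_{j=1}^{N}\sum_{z_j}Q_j^{(t)}(z_j)\log\frac{p(y_j,z_j;\theta)}{Q_j^{(t)}(z_j)}$$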


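因为每一次迭代均是极大值故有

$$L(\theta^{(t+1)})\ge\sum_{j}\sum_{z_j}Q_j^{(t)}(z_j)\log\frac{p(y_j,z_j;\theta^{(t+1)})}{Q_j^{(t)}(z_j)}\ge\sum_{j}\sum_{z_j}Q_j^{(t)}(z_j)\log\frac{p(y_j,z_j;\theta^{(t)})}{Q_j^{(t)}(z_j)}=L(\theta^{(t)})$$

因此迭代过程中似然函数 $L(\theta)$ 单调递增(不减),又因为似然函数有上界,故EM算法收敛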


二、混合高斯

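高斯混合概率分布如下

$$P(y\mid\theta)=\sum_{k=1}^{K}\alpha_k\,\phi(y\mid\theta_k),\qquad \phi(y\mid\theta_k)=\frac{1}{\sqrt{2\pi\sigma_k^2}}\exp\left(-\frac{(y-\mu_k)^2}{2\sigma_k^2}\right)$$

其中 $\alpha_k\ge 0$,$\sum_{k=1}^{K}\alpha_k=1$,$\theta_k=(\mu_k,\sigma_k^2)$。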

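隐变量为 $\gamma_{jk}$,即第 $j$ 个观测值是否来自第 $k$ 个高斯模型,取0或1,也有的资料写成 $z_j=k$,其实想表达的意思是一样的。在当前参数下,它的期望(即第 $j$ 个观测值来自第 $k$ 个高斯模型的概率)为


$$\hat\gamma_{jk}=E[\gamma_{jk}\mid y_j,\theta]=\frac{\alpha_k\,\phi(y_j\mid\theta_k)}{\sum_{l=1}^{K}\alpha_l\,\phi(y_j\mid\theta_l)}$$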


这里第 $j$ 个观测值来自第 $k$ 个高斯模型的后验概率 $\hat\gamma_{jk}$ 即为第一部分推导中的隐变量分布 $Q_j(z_j=k)$


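目标函数

$$J(\theta)=\sum_{j=1}^{N}\sum_{k=1}^{K}\hat\gamma_{jk}\log\frac{\alpha_k\,\phi(y_j\mid\theta_k)}{\hat\gamma_{jk}}=\sum_{j=1}^{N}\sum_{k=1}^{K}\hat\gamma_{jk}\left[\log\alpha_k-\frac{1}{2}\log(2\pi\sigma_k^2)-\frac{(y_j-\mu_k)^2}{2\sigma_k^2}-\log\hat\gamma_{jk}\right]$$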

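利用极值法(记上面的目标函数为 $J$,$\hat\gamma_{jk}$ 为第 $j$ 个观测值来自第 $k$ 个高斯模型的后验概率)有:

$$\frac{\partial J}{\partial\mu_k}=\sum_{j=1}^{N}\hat\gamma_{jk}\frac{y_j-\mu_k}{\sigma_k^2}=0,\qquad \frac{\partial J}{\partial\sigma_k^2}=\sum_{j=1}^{N}\hat\gamma_{jk}\left[\frac{(y_j-\mu_k)^2}{2\sigma_k^4}-\frac{1}{2\sigma_k^2}\right]=0 \tag{1}$$

考虑约束条件 $\sum_{k=1}^{K}\alpha_k=1$,由拉格朗日乘数法

$$\frac{\partial}{\partial\alpha_k}\left[J+\lambda\left(\sum_{k=1}^{K}\alpha_k-1\right)\right]=\sum_{j=1}^{N}\frac{\hat\gamma_{jk}}{\alpha_k}+\lambda=0 \tag{2}$$



$$\sum_{k=1}^{K}\alpha_k=1 \tag{3}$$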

求解(1)(2)(3)即可得到混合高斯参数的值


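参数的求解如下($\hat\gamma_{jk}$ 为第 $j$ 个观测值来自第 $k$ 个高斯模型的后验概率,$N$ 为观测值个数):

$$\hat\mu_k=\frac{\sum_{j=1}^{N}\hat\gamma_{jk}\,y_j}{\sum_{j=1}^{N}\hat\gamma_{jk}},\qquad \hat\sigma_k^2=\frac{\sum_{j=1}^{N}\hat\gamma_{jk}\,(y_j-\hat\mu_k)^2}{\sum_{j=1}^{N}\hat\gamma_{jk}},\qquad \hat\alpha_k=\frac{\sum_{j=1}^{N}\hat\gamma_{jk}}{N}$$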





三、以下是混合高斯的实验,opencv2.4.9

  1. // HelloOpenCV.cpp : Defines the entry point for the console application.
  2. //
  3. #include "stdafx.h"
  4. #include <opencv2/opencv.hpp>
  5. #include <iostream>
  6. #include <math.h>
  7. #define COUNT 4
  8. #define HIST_SIZE 256
  9. #define EPS 1e-32
  10. using namespace std;
  11. using namespace cv;
  12. void Gaussian_fun(double *u, double *delta, double *p, unsigned char grey)
  13. {
  14. int i;
  15. for (i = 0; i < COUNT; i++)
  16. {
  17. p[i] = (0.39894228*exp(-pow(grey - u[i], 2) / 2 / (delta[i] + EPS))) / sqrt(delta[i] + EPS);
  18. }
  19. }
  20. void Drawhist(CvHistogram *hist, IplImage *hist_img, CvScalar scalar)
  21. {
  22. int i;
  23. float MaxValue;
  24. double probality, probality_old;
  25. CvSize size = cvGetSize(hist_img);
  26. cvGetMinMaxHistValue(hist, 0, &MaxValue, 0);
  27. probality_old = cvGetReal1D(hist->bins, 0);
  28. probality_old = cvRound(HIST_SIZE*probality_old / size.height/size.width);
  29. for (i = 1; i < 256; i++)
  30. {
  31. probality = cvGetReal1D(hist->bins, i);
  32. probality = cvRound(HIST_SIZE*probality / size.height/size.width);
  33. cvLine(hist_img, cvPoint(i - 1, 1.5*(128 - probality_old)), cvPoint(i, 1.5*(128 - probality)), scalar);
  34. probality_old = probality;
  35. }
  36. }
  // Euclidean (L2) distance between two parameter vectors.
  //   new_mat, old_mat : each must point to 4 contiguous doubles (the callers in this
  //                      file pass their double[COUNT] parameter arrays, COUNT == 4)
  // Returns sqrt(sum_k (new_mat[k] - old_mat[k])^2).
  //
  // The previous version wrapped the buffers in CvMat headers but bound mat1 twice and
  // never bound mat2 (so the norm was taken against uninitialised memory), declared the
  // elements as 32-bit floats although the buffers hold doubles, and never released the
  // matrices it created. Computing the norm directly avoids all three problems.
  double Distance(void *new_mat, void *old_mat)
  {
      const int kLength = 4; // same vector length as the original 4x1 matrices
      const double *current = static_cast<const double *>(new_mat);
      const double *previous = static_cast<const double *>(old_mat);

      double squared = 0.0;
      for (int k = 0; k < kLength; ++k)
      {
          const double diff = current[k] - previous[k];
          squared += diff * diff;
      }
      return sqrt(squared);
  }
  45. void EM_GMM(IplImage *img, IplImage *hist_img)
  46. {
  47. int i, j, iter = 0;
  48. unsigned char grey;
  49. double alpha[COUNT] = { 0.25, 0.25, 0.25, 0.25 }, delta[COUNT] = { 20, 20, 20, 20 }, u[COUNT] = { 50, 100, 150, 200 };
  50. double u_old[COUNT] = { 0 }, alpha_old[COUNT] = { 0 }, delta_old[COUNT] = { 0 };
  51. double p[COUNT] = { 0 };
  52. CvSize size = cvGetSize(img);
  53. double sum_p, sum_gamma[COUNT] = { 0 }, sum_gammay[COUNT] = { 0 }, sum_gammayy[COUNT] = { 0 }, gamma[COUNT] = { 0 };
  54. while (iter < 1000 && Distance(alpha, alpha_old) > 0.01 && Distance(u, u_old) > 0.01 && Distance(delta, delta_old) > 0.01)
  55. {
  56. memset(gamma, 0, sizeof(gamma));
  57. memset(sum_gamma, 0, sizeof(sum_gamma));
  58. memset(sum_gammay, 0, sizeof(sum_gamma));
  59. memset(sum_gammayy, 0, sizeof(sum_gamma));
  60. for (i = 0; i < size.height; i++)
  61. {
  62. for (j = 0; j < size.width; j++)
  63. {
  64. grey = img->imageData[i*size.width + j];
  65. Gaussian_fun(u, delta, p, grey);
  66. sum_p = alpha[0] * p[0] + alpha[1] * p[1] + alpha[2] * p[2] + alpha[3] * p[3] + EPS;
  67. gamma[0] = alpha[0] * p[0] / sum_p;
  68. gamma[1] = alpha[1] * p[1] / sum_p;
  69. gamma[2] = alpha[2] * p[2] / sum_p;
  70. gamma[3] = alpha[3] * p[3] / sum_p;
  71. sum_gamma[0] += gamma[0];
  72. sum_gamma[1] += gamma[1];
  73. sum_gamma[2] += gamma[2];
  74. sum_gamma[3] += gamma[3];
  75. sum_gammay[0] += gamma[0] * grey;
  76. sum_gammay[1] += gamma[1] * grey;
  77. sum_gammay[2] += gamma[2] * grey;
  78. sum_gammay[3] += gamma[3] * grey;
  79. sum_gammayy[0] += gamma[0] * grey * grey;
  80. sum_gammayy[1] += gamma[1] * grey * grey;
  81. sum_gammayy[2] += gamma[2] * grey * grey;
  82. sum_gammayy[3] += gamma[3] * grey * grey;
  83. }
  84. }
  85. for (i = 0; i < 4; i++)
  86. {
  87. alpha_old[i] = alpha[i];
  88. u_old[i] = u[i];
  89. delta_old[i] = delta[i];
  90. alpha[i] = sum_gamma[i] / size.height / size.width;
  91. u[i] = sum_gammay[i] / (sum_gamma[i]+EPS);
  92. delta[i] = (sum_gammayy[i] - 2 * u[i] * sum_gammay[i] + u[i] * u[i] * sum_gamma[i]) / (sum_gamma[i]+EPS);
  93. }
  94. iter++;
  95. }
  96. int sizes = 256;
  97. float range[2] = { 0, 255 };
  98. float *ranges = range;
  99. CvHistogram *hist;
  100. hist = cvCreateHist(1, &sizes, CV_HIST_ARRAY, &ranges, 1);
  101. for (i = 0; i < 256; i++)
  102. {
  103. Gaussian_fun(u, delta, p, i);
  104. sum_p = alpha[0] *p[0] + alpha[1] * p[1] + alpha[2] * p[2] + alpha[3] * p[3];
  105. // cout << sum_p << endl;
  106. cvSetReal1D(hist->bins,i,cvRound(sum_p * size.width*size.height));
  107. }
  108. Drawhist(hist, hist_img, cvScalar(255,0,0,0));
  109. }
  110. int main(int argc, const char* argv[])
  111. {
  112. IplImage *img;
  113. img = cvLoadImage("..\\starry_night.jpg", 1);
  114. IplImage *imgRed = cvCreateImage(cvGetSize(img), 8, 1);
  115. IplImage *imgGreen = cvCreateImage(cvGetSize(img), 8, 1);
  116. IplImage *imgBlue = cvCreateImage(cvGetSize(img), 8, 1);
  117. cvSplit(img, imgRed, imgGreen, imgBlue, 0);
  118. namedWindow("img", CV_WINDOW_AUTOSIZE);
  119. cvShowImage("img", img);
  120. waitKey(0);
  121. int sizes = 256;
  122. float range[] = { 0, 255 };
  123. float*ranges[] = { range };
  124. CvHistogram *hist = cvCreateHist(1, &sizes, CV_HIST_ARRAY, ranges, 1);
  125. cvCalcHist(&imgRed, hist, 0, 0);
  126. IplImage *hist_img = cvCreateImage(cvSize(256, 256), IPL_DEPTH_8U, 3);
  127. Drawhist(hist, hist_img, cvScalar(0, 0, 255, 0));
  128. cvClearHist(hist);
  129. EM_GMM(imgRed, hist_img);
  130. cvNamedWindow("EM&GMM");
  131. cvShowImage("EM&GMM", hist_img);
  132. waitKey(0);
  133. cvDestroyAllWindows();
  134. return 0;
  135. }


ps:  红色为图像的直方图,蓝色为拟合曲线





参考文献

[1]李航 统计学习方法

[2]Andrew.Ng MachineLearning课件

[3] JerryLead http://www.cnblogs.com/jerrylead/archive/2011/04/06/2006936.html


易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!