一、机器学习中的参数估计问题

史上简单易学的机器学习算法——EM算法 缘木求鱼_python

二、EM算法简介

    在上述存在隐变量的问题中,不能直接通过极大似然估计求出模型中的参数,EM算法是一种解决存在隐含变量优化问题的有效方法。EM算法是期望极大(Expectation Maximization)算法的简称,EM算法是一种迭代型的算法,在每一次的迭代过程中,主要分为两步:即求期望(Expectation)步骤和最大化(Maximization)步骤。

三、EM算法推导的准备

史上简单易学的机器学习算法——EM算法 缘木求鱼_数学期望_02

史上简单易学的机器学习算法——EM算法 缘木求鱼_迭代_03

(图片来自参考文章1)

注:若函数

史上简单易学的机器学习算法——EM算法 缘木求鱼_html_04

是凹函数,上述的符号相反。

3、数学期望

史上简单易学的机器学习算法——EM算法 缘木求鱼_数学期望_05

四、EM算法的求解过程   

史上简单易学的机器学习算法——EM算法 缘木求鱼_html_06

史上简单易学的机器学习算法——EM算法 缘木求鱼_数学期望_07

史上简单易学的机器学习算法——EM算法 缘木求鱼_数学期望_08

五、EM算法的收敛性保证

史上简单易学的机器学习算法——EM算法 缘木求鱼_迭代_09

六、利用EM算法参数求解实例

   

史上简单易学的机器学习算法——EM算法 缘木求鱼_html_10

Python代码

#coding:UTF-8
'''
Created on 2015年6月7日

@author: zhaozhiyong
'''
from __future__ import division
from numpy import *
import math as mt
#首先生成一些用于测试的样本
#指定两个高斯分布的参数,这两个高斯分布的方差相同
sigma = 6
miu_1 = 40
miu_2 = 20

#随机均匀选择两个高斯分布,用于生成样本值
N = 1000
X = zeros((1, N))
for i in xrange(N):
    if random.random() > 0.5:#使用的是numpy模块中的random
        X[0, i] = random.randn() * sigma + miu_1
    else:
        X[0, i] = random.randn() * sigma + miu_2

#上述步骤已经生成样本
#对生成的样本,使用EM算法计算其均值miu

#取miu的初始值
k = 2
miu = random.random((1, k))
#miu = mat([40.0, 20.0])
Expectations = zeros((N, k))

for step in xrange(1000):#设置迭代次数
    #步骤1,计算期望
    for i in xrange(N):
        #计算分母
        denominator = 0
        for j in xrange(k):
            denominator = denominator + mt.exp(-1 / (2 * sigma ** 2) * (X[0, i] - miu[0, j]) ** 2)
        
        #计算分子
        for j in xrange(k):
            numerator = mt.exp(-1 / (2 * sigma ** 2) * (X[0, i] - miu[0, j]) ** 2)
            Expectations[i, j] = numerator / denominator
    
    #步骤2,求期望的最大
    #oldMiu = miu
    oldMiu = zeros((1, k))
    for j in xrange(k):
        oldMiu[0, j] = miu[0, j]
        numerator = 0
        denominator = 0
        for i in xrange(N):
            numerator = numerator + Expectations[i, j] * X[0, i]
            denominator = denominator + Expectations[i, j]
        miu[0, j] = numerator / denominator
        
    
    #判断是否满足要求
    epsilon = 0.0001
    if sum(abs(miu - oldMiu)) < epsilon:
        break
    
    print step
    print miu
    
print miu

最终结果

[[ 40.49487592  19.96497512]]