2016-03-23 9 views
3

ガウス混合のEMアルゴリズムを実装しようとしています。 - 私は、次のチュートリアルに自分自身を基づか:EMアルゴリズムによるバッファオーバーフロー

http://people.duke.edu/~ccc14/sta-663/EMAlgorithm.html

アルゴリズムは、いくつかの値で正常に動作しますが、数が大きくなっているとき、私は通常の行列の無限大を取得します。

Here'sは私のコード:

from scipy.stats import multivariate_normal as mvn 
import numpy as np 
import matplotlib.pyplot as plt 

def em_gmm_orig(xs, pis, mus, sigmas, tol=0.01, max_iter=100): 

    n, p = xs.shape 
    k = len(pis) 

    ll_old = 0 
    for i in range(max_iter): 
     exp_A = [] 
     exp_B = [] 
     ll_new = 0 

     # E-step 
     ws = np.zeros((k, n)) 
     for j in range(len(mus)): 
      for i in range(n): 
       ws[j, i] = pis[j] * mvn(mus[j], sigmas[j]).pdf(xs[i]) 
     ws /= ws.sum(0) 

     # M-step 
     pis = np.zeros(k) 
     for j in range(len(mus)): 
      for i in range(n): 
       pis[j] += ws[j, i] 
     pis /= n 

     mus = np.zeros((k, p)) 
     for j in range(k): 
      for i in range(n): 
       mus[j] += ws[j, i] * xs[i] 
      mus[j] /= ws[j, :].sum() 

     sigmas = np.zeros((k, p, p)) 
     for j in range(k): 
      for i in range(n): 
       ys = np.reshape(xs[i]- mus[j], (2,1)) 
       sigmas[j] += ws[j, i] * np.dot(ys, ys.T) 
      sigmas[j] /= ws[j,:].sum() 

     # update complete log likelihoood 
     ll_new = 0.0 
     for i in range(n): 
      s = 0 
      for j in range(k): 
       s += pis[j] * mvn(mus[j], sigmas[j]).pdf(xs[i]) 
      ll_new += np.log(s) 

     if np.abs(ll_new - ll_old) < tol: 
      break 
     ll_old = ll_new 

    return ll_new, pis, mus, sigmas 


np.random.seed(123) 

# create data set 
n = 100 
for i in range(0,4): 
    num1 = np.random.uniform(-500,500) 
    num2 = np.random.uniform(-500,500) 
    num3 = np.random.uniform(-500,500) 
    num4 = np.random.uniform(-500,500) 
_mus = np.array([[num1,num2], [num3,num4]]) 

_sigmas = np.array([[[3, 0], [0, 0.5]], [[1,0],[0,2]]]) 
_pis = np.array([0.6, 0.4]) 
xs = np.concatenate([np.random.multivariate_normal(mu, sigma, int(pi*n)) 
        for pi, mu, sigma in zip(_pis, _mus, _sigmas)]) 

# initial guesses for parameters 
num_clusters = 4 
pis = np.random.random(num_clusters) 
pis /= pis.sum() 
mus = np.random.random((num_clusters,2)) 
print(mus) 
sigmas = np.array([np.eye(2)] * num_clusters) 

ll1, pis1, mus1, sigmas1 = em_gmm_orig(xs, pis, mus, sigmas) 

#Ploy 
intervals = 101 
ys = np.linspace(-8,8,intervals) 
X, Y = np.meshgrid(ys, ys) 
_ys = np.vstack([X.ravel(), Y.ravel()]).T 

z = np.zeros(len(_ys)) 
for pi, mu, sigma in zip(pis1, mus1, sigmas1): 
    z += pi*mvn(mu, sigma).pdf(_ys) 
z = z.reshape((intervals, intervals)) 

ax = plt.subplot(111) 
plt.scatter(xs[:,0], xs[:,1], alpha=0.2) 
plt.contour(X, Y, z, N=10) 
plt.axis([-8,6,-6,8]) 
ax.axes.set_aspect('equal') 
plt.tight_layout()  

- 私は通常、次のエラーを取得:

Traceback (most recent call last): 
     File "C:\Usy", line 82, in <module> 
     ll1, pis1, mus1, sigmas1 = em_gmm_orig(xs, pis, mus, sigmas) 
     File "C:\Usy", line 48, in em_gmm_orig 
     s += pis[j] * mvn(mus[j], sigmas[j]).pdf(xs[i]) 
     File "y", line 354, in __call__ 
     allow_singular=allow_singular) 
... 
    ValueError: array must not contain infs or NaNs 

は、私はこの問題を解決することができる方法はありますか?

答えて

0

おそらく、あなたのシグマ値は、特異点を与えるmvn(mus [j]、sigmas [j]).pdf(xs [i])ではゼロです。あなたのシグマ値をチェックしたいかもしれません。

関連する問題