2016-12-28 14 views
1

私は2つのガウス分布をプロットし、交差点を見つけることを試みています。私は次のコードを持っています。それは正確な交差点をプロットしていないと私は本当に理由を把握することはできません。それはほんのわずかに外れているようですが、引き算されたガウス分布のログをとった場合は、派生した解を試してみました。正しいと思われるようです。誰も助けることができますか?どうもありがとうございます!ガウスの交点

import numpy as np 
import matplotlib.pyplot as plt 

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

# found online 
def solve_gasussians(m1, s1, m2, s2): 
    a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2) 
    b = m2/(s2**2) - m1/(s1**2) 
    c = m1**2 /(2*s1**2) - m2**2/(2.0*s2**2) - np.log(s2/s1) 
    return np.roots([a,b,c]) 

s1 = np.linspace(0, 10,300) 
s2 = np.linspace(0, 14, 300) 

solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0) 
print solved_val 
solved_val = solved_val[0] 
plt.figure('Baseline Distributions') 
plt.title('Baseline Distributions') 
plt.xlabel('Response Rate') 
plt.ylabel('Probability') 
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1') 
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2') 
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo') 
plt.legend() 
plt.show() 
+0

あなたは私たちがするように、あなたがオンラインで見つけるソリューションに私たちを指すことができます自分のためにそれを派生させる必要はありませんか? –

+0

私は彼らが言及している解決策は、この[あなたの質問](http://stackoverflow.com/a/22579904/752843)かもしれないと思います。だから私たちは、コードにコメントすることの怠慢な欠如を完全に非難することはできません。 – Richard

+0

@リチャード、それも私の考えです。 –

答えて

0

をあなたはplot_normal機能の小さなバグを持っている - あなたは分母に平方根を欠けています。適切なバージョン: enter image description here

そして2備考:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

は、期待される結果を提供します。

  1. 一般的に式の2つの根(2つの交点)を持つことができます。これは、指定したパラメータでのケースです。
  2. は、私の知る限りnp.rootsはあなたのおおよその結果を与える知っているが、あなたは猫のようsolve_gasussians機能を書き換え、簡単に正確な結果を得るよう:

    def solve_gasussians(m1, s1, m2, s2): 
        # coefficients of quadratic equation ax^2 + bx + c = 0 
        a = (s1**2.0) - (s2**2.0) 
        b = 2 * (m1 * s2**2.0 - m2 * s1**2.0) 
        c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2) 
        x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c))/(2.0 * a) 
        x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c))/(2.0 * a) 
        return x1, x2 
    
0

私はあなたのコードのどこに間違いがあるのか​​わかりません。しかし、私はあなたが借りたコードを見つけ出し、必要な調整の一部を行ったと思います。

import numpy as np 
import matplotlib.pyplot as plt 
from scipy.stats import norm 

def solve(m1,m2,std1,std2): 
    a = 1/(2*std1**2) - 1/(2*std2**2) 
    b = m2/(std2**2) - m1/(std1**2) 
    c = m1**2 /(2*std1**2) - m2**2/(2*std2**2) - np.log(std2/std1) 
    return np.roots([a,b,c]) 

m1 = 5 
std1 = 0.5 
m2 = 7 
std2 = 1 

result = solve(m1,m2,std1,std2) 

x = np.linspace(-5,9,10000) 
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x]) 
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x]) 
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o') 

plt.show() 

私は(彼らが私のために行うように)あなたのための生活を楽かもしれない迷惑なアドバイスの2枚提供されます:

  • あなたがコードを適応させる

    が小さい、増分変更をしようとし、各ステップでコードが動作することを確認してください。
  • 既存のフリーライブラリを探します。この場合、normからscipyは、元のコードで使用されていたものの良い代替品です。
0

ここに間違いがあります。このライン:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

あなたはsqrtを忘れてしまった:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

がこのであるべき。

それが利用できる場合には、次のような、既存の通常のPDFファイルを使用することが賢明だろう:それは正確に交差点のために解決することも可能です

import scipy.stats 
def plot_normal(x, mean = 0, sigma = 1): 
    return scipy.stats.norm.pdf(x,loc=mean,scale=sigma) 

This answerは、ガウスの交点の根の二次方程式を提供します。 maximaを使ってxを解くと、次の式が得られます。これは複雑ではあるが、反復メソッドに依存せず、簡単な式から自動的に生成することができます。完全にそれを置く

def solve_gaussians(m1,s1,m2,s2): 
    x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2) 
    x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2) 
    return x1,x2 

が与える:

import numpy as np 
import matplotlib.pyplot as plt 
import scipy.stats 

def plot_normal(x, mean = 0, sigma = 1): 
    return scipy.stats.norm.pdf(x,loc=mean,scale=sigma) 

#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x 
def solve_gaussians(m1,s1,m2,s2): 
    x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2) 
    x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2) 
    return x1,x2 

s = np.linspace(0, 14,300) 
x = solve_gaussians(5.0,0.5,7.0,1.0) 

plt.figure('Baseline Distributions') 
plt.title('Baseline Distributions') 
plt.xlabel('Response Rate') 
plt.ylabel('Probability') 
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1') 
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2') 
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo') 
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo') 
plt.legend() 
plt.show() 

を与える:

Intersection of Gaussians