2017-04-12 9 views
1

ベータ版の配布には_stats関数をオーバーロードする必要があります。これは私の現在のコードです:私のカスタムオブジェクトから_rvs()を取得Python scipy overloadベータ版の_stats関数

from scipy.stats import beta 
import scipy.stats as st 

class CustomBeta(st.rv_continuous): 
    def _stats(self, a, b): 
     # will add own code here 
     mn = a * 1.0/(a + b) 
     var = (a * b * 1.0)/(a + b + 1.0)/(a + b) ** 2.0 
     g1 = 2.0 * (b - a) * sqrt((1.0 + a + b)/(a * b))/(2 + a + b) 
     g2 = 6.0 * (a ** 3 + a ** 2 * (1 - 2 * b) + b ** 2 * (1 + b) - 2 * a * b * (2 + b)) 
     g2 /= a * b * (a + b + 2) * (a + b + 3) 
     return mn, var, g1, g2 

dist = beta(4, 6) 
print dist.rvs() # works fine 


dist = CustomBeta(4, 6) 
print dist.rvs() # crashes 

は私に長いスタックトレースとエラー

RuntimeError: maximum recursion depth exceeded

答えて

2

を与えるこれは_statsを過負荷とは何の関係もありません。同じ動作を

class CustomBeta(st.rv_continuous): 
    pass 

dist = CustomBeta(4, 6) 
print(dist.rvs()) # crashes 

documentation of rv_continuous状態によって単に

New random variables can be defined by subclassing the rv_continuous class and re-defining at least the _pdf or the _cdf method.

あなたは確率密度関数(PDF)または累積確率密度を計算するために、これらの方法のうちの少なくとも1つを提供する必要があることに起因します関数(cdf)です。これは次のように使用されている

さらに、

[ rv_continuous ] cannot be used directly as a distribution.

class CustomBetaGen(st.rv_continuous): 
    ... 

CustomBeta = CustomBetaGen(name='CustomBeta') 

dist = CustomBeta(4, 6) 

最後に、rvs.()はあなたが_rvs方法を提供しない場合はベータ分布のために正常に動作していないようです。

一緒にすべてを置くとベータ分布から適切な方法を盗む:私はこのエラーを取得する

from scipy.stats import beta 
import scipy.stats as st 
import numpy as np 

class CustomBetaGen(st.rv_continuous): 
    def _cdf(self, x, a, b): 
     return beta.cdf(x, a, b) 
    def _pdf(self, x, a, b): 
     return beta.pdf(x, a, b) 
    def _rvs(self, a, b): 
     return beta.rvs(a, b) 
    def _stats(self, a, b): 
     # will add own code here 
     mn = a * 1.0/(a + b) 
     var = (a * b * 1.0)/(a + b + 1.0)/(a + b) ** 2.0 
     g1 = 2.0 * (b - a) * np.sqrt((1.0 + a + b)/(a * b))/(2 + a + b) 
     g2 = 6.0 * (a ** 3 + a ** 2 * (1 - 2 * b) + b ** 2 * (1 + b) - 2 * a * b * (2 + b)) 
     g2 /= a * b * (a + b + 2) * (a + b + 3) 
     return mn, var, g1, g2 

CustomBeta = CustomBetaGen(name='CustomBeta') 

dist = beta(4, 6) 
print(dist.rvs()) # works fine 
print(dist.stats()) # (array(0.4), array(0.021818181818181816)) 

dist = CustomBeta(4, 6) 
print(dist.rvs()) # works fine 
print(dist.stats()) # (array(0.4), array(0.021818181818181816)) 
+0

を、はAttributeErrorは:「beta_gen」オブジェクトがN00B @無属性「_size」 – n00b

+0

を持っていない私も、この奇妙なエラーが発生しましたが、それをコードで修正するのを忘れた。 'return beta._rvs(a、b)'を 'return beta.rvs(a、b)'に変更してみてください。 – kazemakase

関連する問題