2016-08-13

I'm writing logistic regression code in Scala. I'm trying to use Scala Breeze, but I get an error when I test my code. How do I write logistic regression with LBFGS in Scala Breeze?

Here is my code. Basically, I tried to write the same thing as here.

import breeze.linalg._
import breeze.optimize._

def lbfgsSolve(features: CSCMatrix[Double], outputs: SparseVector[Double], lambda: Double = 0.0) = {
  val obj = new DiffFunction[SparseVector[Double]] {
    // Returns (loss, gradient) at `weights`, which is what LBFGS expects from a DiffFunction.
    override def calculate(weights: SparseVector[Double]): (Double, SparseVector[Double]) = {

      // Overflow-safe elementwise logistic function.
      def sigmoid(w: SparseVector[Double]): SparseVector[Double] = {
        w.map(x => if (x > 0) 1.0 / (1.0 + Math.exp(-x)) else Math.exp(x) / (1.0 + Math.exp(x)))
      }

      val m = features.rows.toDouble
      val z = features * weights   // linear scores x_i . w
      val yz = outputs :* z        // margins y_i * z_i
      val theta = weights.copy     // copy of weights with the intercept (index 0) zeroed
      theta(0) = 0.0

      // Overflow-safe log(1 + exp(-yz)), elementwise.
      val out = yz.map { x => if (x > 0) Math.log(Math.exp(-x) + 1.0) else -x + Math.log(Math.exp(x) + 1.0) }
      // Mean logistic loss plus an L2 penalty on all of `weights` (intercept included).
      val loss = (sum(out) / m) + (0.5 * lambda * (weights.t * weights))

      val zz = sigmoid(yz)
      val z0 = (zz - 1.0) :* outputs
      // Gradient of the mean loss plus the L2 penalty gradient on `theta` (intercept excluded).
      val gradient = ((features.t * z0) / m) + (theta * lambda)

      (loss, gradient)
    }
  }
  // Start from all-ones weights.
  val initWeights = SparseVector(Array.fill(features.cols)(1.0))
  new LBFGS[SparseVector[Double]](tolerance = 0.01).minimize(obj, initWeights)
}
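Written out as formulas, the `calculate` method above computes the following, assuming labels y_i in {-1, +1} (the log(1 + exp(-yz)) form expects that; the post does not say how "approved" is encoded). Here x_i is row i of `features`, m = `features.rows`, n = `features.cols`, e_0 is the first unit vector, and θ is `weights` with the intercept entry set to 0. Note that the penalty in the loss uses w, while the penalty in the gradient uses θ.

```latex
\begin{aligned}
L(w) &= \frac{1}{m}\sum_{i=1}^{m} \log\left(1 + e^{-y_i x_i^{\top} w}\right) + \frac{\lambda}{2}\, w^{\top} w \\
g(w) &= \frac{1}{m}\sum_{i=1}^{m} \left(\sigma(y_i x_i^{\top} w) - 1\right) y_i\, x_i + \lambda\, \theta,
  \qquad \theta = (0, w_1, \dots, w_{n-1}) \\
\frac{d}{dt}\log\left(1 + e^{-t}\right) &= \sigma(t) - 1
  \;\Rightarrow\;
  \nabla L(w) = \frac{1}{m}\sum_{i=1}^{m} \left(\sigma(y_i x_i^{\top} w) - 1\right) y_i\, x_i + \lambda\, w \\
\nabla L(w) - g(w) &= \lambda\, w_0\, e_0
\end{aligned}
```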

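I'm using this dataset to test the code. "Approved" is the output and the rest are features. I have already normalized the features (en.wikipedia.org/wiki/Feature_scaling) and added a constant column as the first column, so my feature matrix looks like this: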

scala> features(0 until 4, 0 until 4) 
res18: breeze.linalg.SliceMatrix[Int,Int,Double] = 
1.0 0.27586206896551724 0.7758620689655172 0.6666666666666666 
1.0 0.7586206896551724 0.8103448275862069 0.6666666666666666 
1.0 1.0     1.0     0.0 
1.0 0.7241379310344828 0.5344827586206897 1.0 
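The preprocessing described above (rescale each feature column, then prepend a constant column) might look like this in plain Scala. This is a sketch under an assumption: the post only links the Wikipedia article, and min-max rescaling to [0, 1] is inferred from the values shown. `PreprocessSketch` and `minMaxWithIntercept` are hypothetical names, and converting the result to a `CSCMatrix` is not shown.

```scala
object PreprocessSketch {
  // Hypothetical helper: min-max rescale every column of `rows` to [0, 1],
  // then prepend a constant 1.0 column for the intercept.
  def minMaxWithIntercept(rows: Array[Array[Double]]): Array[Array[Double]] = {
    val nCols = rows.head.length
    val mins = Array.tabulate(nCols)(j => rows.map(_(j)).min)
    val maxs = Array.tabulate(nCols)(j => rows.map(_(j)).max)
    rows.map { row =>
      val scaled = Array.tabulate(nCols) { j =>
        val range = maxs(j) - mins(j)
        // A constant column has no range; map it to 0.0 instead of dividing by zero.
        if (range == 0.0) 0.0 else (row(j) - mins(j)) / range
      }
      1.0 +: scaled
    }
  }
}
```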

When I run this function on the features, it reports an error:

scala> val answer = lbfgsSolve(features, outputs, 0.05) 
[run-main-0] INFO breeze.optimize.LBFGS - Step Size: 14.07 
[run-main-0] INFO breeze.optimize.LBFGS - Val and Grad Norm: 0.566596 (rel: 0.0459) 0.0357693 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.1 fval: 0.5684918452517186 rhs: 0.5665961695023995 cdd: -0.024619310343914663 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.010859375861321754 fval: 0.5667857913421284 rhs: 0.5665963991719202 cdd: -0.025639964206204915 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.0012166564062106571 fval: 0.5666174432637753 rhs: 0.5665964240162465 cdd: -0.025750856078277946 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.36780897002546E-4 fval: 0.5665987873427688 rhs: 0.5665964267985301 cdd: -0.025763280624522322 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.538334048641963E-5 fval: 0.5665966925628935 rhs: 0.5665964271113091 cdd: -0.025764677442936045 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.730193655418023E-6 fval: 0.5665964570019302 rhs: 0.5665964271464863 cdd: -0.025764834539024242 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.9459913173710706E-7 fval: 0.5665964305083527 rhs: 0.5665964271504427 cdd: -0.025764852207922004 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 2.1887054476776026E-8 fval: 0.5665964275285602 rhs: 0.5665964271508876 cdd: -0.02576485419518635 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 2.4616922533245922E-9 fval: 0.5665964271934153 rhs: 0.5665964271509377 cdd: -0.025764854418698992 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 2.7687332010108064E-10 fval: 0.5665964271557209 rhs: 0.5665964271509434 cdd: -0.025764854443838008 
[run-main-0] ERROR breeze.optimize.LBFGS - Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.10402596383408635 fval: 0.5666979644593371 rhs: 0.5665964138414288 cdd: -0.0012757889755073697 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.011003739451576303 fval: 0.5666071180772083 rhs: 0.5665964257430798 cdd: -0.0012790548704777202 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 0.001166003720069149 fval: 0.5665975594520045 rhs: 0.5665964270017607 cdd: -0.001279400679810798 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.2357771173880845E-4 fval: 0.5665965471504739 rhs: 0.566596427135133 cdd: -0.0012794373271490386 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.309751490251532E-5 fval: 0.5665964398691514 rhs: 0.5665964271492683 cdd: -0.0012794412112233761 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.388156908600053E-6 fval: 0.5665964284988991 rhs: 0.5665964271507664 cdd: -0.0012794416228816064 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.4712560321674266E-7 fval: 0.5665964272938085 rhs: 0.5665964271509252 cdd: -0.0012794416665117385 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.5593311038744415E-8 fval: 0.5665964271660856 rhs: 0.566596427150942 cdd: -0.0012794416711359338 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.6526860200416176E-9 fval: 0.5665964271525491 rhs: 0.5665964271509438 cdd: -0.0012794416716260355 
[run-main-0] INFO breeze.optimize.StrongWolfeLineSearch - Line search t: 1.7514997931281254E-10 fval: 0.5665964271511141 rhs: 0.566596427150944 cdd: -0.00127944167167798 
[run-main-0] ERROR breeze.optimize.LBFGS - Failure again! Giving up and returning. Maybe the objective is just poorly behaved? 
[run-main-0] INFO breeze.optimize.LBFGS - Converged because line search failed! 
answer: breeze.linalg.SparseVector[Double] = SparseVector((0,1.281051822587718), (1,0.47902016540238035), (2,0.4807986641770212), (3,0.3835424930764545)) 

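How can I fix this so that it works? If possible, I'd like to keep using CSCMatrix and SparseVector. I'm using Scala 2.11.8, Breeze 0.12, SLF4J-Simple 1.7.6 and OpenJDK 1.8.0.

Answer

Without looking too carefully, your gradient is probably wrong. If you use the GradientTester class, it will print diagnostic information about it.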
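A hand-rolled finite-difference check in the same spirit as GradientTester can be sketched in plain Scala. It uses no Breeze, so it does not show GradientTester's actual API or internals. It reimplements the question's loss and gradient on a tiny made-up dataset with ±1 labels and compares the analytic gradient with central differences of the loss, once with the question's penalty on all of `weights` and once with the penalty on `theta`. `GradientCheckSketch`, the toy data and the `penalizeIntercept` flag are hypothetical.

```scala
object GradientCheckSketch {
  // Same two overflow-safe branches as the question's `sigmoid`.
  def sigmoid(x: Double): Double =
    if (x > 0) 1.0 / (1.0 + math.exp(-x)) else math.exp(x) / (1.0 + math.exp(x))

  // Same two overflow-safe branches as the question's `out` map: log(1 + exp(-x)).
  def logLoss(x: Double): Double =
    if (x > 0) math.log(math.exp(-x) + 1.0) else -x + math.log(math.exp(x) + 1.0)

  def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  // Mean logistic loss plus an L2 penalty.
  // penalizeIntercept = true mirrors the question (penalty on all of `weights`);
  // false penalizes `theta`, i.e. the intercept is excluded.
  def loss(x: Array[Array[Double]], y: Array[Double], w: Array[Double],
           lambda: Double, penalizeIntercept: Boolean): Double = {
    val m = x.length.toDouble
    val data = x.indices.map(i => logLoss(y(i) * dot(x(i), w))).sum / m
    val theta = w.clone()
    if (!penalizeIntercept) theta(0) = 0.0
    data + 0.5 * lambda * dot(theta, theta)
  }

  // Analytic gradient as the question computes it (penalty gradient uses theta).
  def gradient(x: Array[Array[Double]], y: Array[Double], w: Array[Double],
               lambda: Double): Array[Double] = {
    val m = x.length.toDouble
    val theta = w.clone()
    theta(0) = 0.0
    val g = Array.fill(w.length)(0.0)
    for (i <- x.indices) {
      val coeff = (sigmoid(y(i) * dot(x(i), w)) - 1.0) * y(i)
      for (j <- w.indices) g(j) += coeff * x(i)(j) / m
    }
    for (j <- w.indices) g(j) += lambda * theta(j)
    g
  }

  // Central finite differences of f at w.
  def numericGradient(f: Array[Double] => Double, w: Array[Double], eps: Double = 1e-6): Array[Double] =
    w.indices.map { j =>
      val plus = w.clone(); plus(j) += eps
      val minus = w.clone(); minus(j) -= eps
      (f(plus) - f(minus)) / (2 * eps)
    }.toArray

  // Toy data (made up): first column is the constant 1.0, labels are +1/-1.
  val xs = Array(Array(1.0, 0.2, 0.7), Array(1.0, 0.8, 0.1), Array(1.0, 0.5, 0.9), Array(1.0, 0.3, 0.4))
  val ys = Array(1.0, -1.0, 1.0, -1.0)
  val lambda = 0.05

  // Largest absolute difference between analytic and numeric gradient components.
  def maxError(penalizeIntercept: Boolean): Double = {
    val w = Array.fill(3)(1.0) // same all-ones start as initWeights in the question
    val analytic = gradient(xs, ys, w, lambda)
    val numeric = numericGradient(v => loss(xs, ys, v, lambda, penalizeIntercept), w)
    analytic.indices.map(j => math.abs(analytic(j) - numeric(j))).max
  }
}
```

With `penalizeIntercept = true` the largest error is about λ·w₀ = 0.05, all of it in the intercept component; with `false` it drops to rounding noise. In the Breeze code, the matching change would be to compute the penalty from `theta` (for example `0.5 * lambda * (theta.t * theta)`) so the loss agrees with the `theta * lambda` term in the gradient. At the returned answer, λ·w₀ ≈ 0.05 × 1.28 ≈ 0.064, which is larger than the gradient norm of about 0.036 in the log, so this mismatch is a plausible cause of the failed line search.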