2

私自身の弾性ネットソルバを書くためのウォームアップとして、座標降下を使用して実装された一般最小二乗法の高速版を取得しようとしています。ジュリアの最小二乗法の座標降下アルゴリズムが収束しない

私は私が正しく座標降下アルゴリズムを実装しましたが、私は「高速」バージョン(下記参照)、を使用する場合、アルゴリズムは日常64ビット浮動小数点をオーバーフロー回帰係数を出力し、めちゃくちゃ不安定であると考えていますフィーチャの数はサンプルの数に比べて適度なサイズである。

線形回帰およびOLS

Aは行列、未知の回帰係数のXAベクトルであり、yは出力であるB = A * X、もし、私は

を最小限にXを検索します|| b - Ax || 2 2

A [j]がAのj番目の列であり、A [-j]が列jを持たないAであり、Aの列が|| A [j ] ||^2 = 1であれば、座標方向の更新は

座標降下:

x[j] <-- A[j]^T * (b - A[-j] * x[-j]) 

私はthese notes (page 9-10)に従っていますが、その導出は簡単な計算です。

代わりに[J]^T(B - [-j] * X [-j])再計算のことを指摘したの座標のすべての時間を、それを行うにはより高速な方法は、高速

であります下降:

x[j] <-- A[j]^T*r + x[j] 

ここで、合計残差r = b-Axは、ループオーバー座標の外側で計算されます。これらの更新規則の等価性は、Ax = A [j] * x [j] + A [-j] * x [-j]に注目し、用語を並べ替えることに続く。

私の問題は、2番目の方法は確かに高速ですが、フィーチャの数がサンプル数に比べて少なくないときはいつも、数値的に不安定であることです。なぜそうなのか誰かが何らかの洞察を持っているのではないかと思っていました。より安定した最初の方法は、フィーチャの数がサンプル数に近づくにつれて、より標準的な方法にはまだ反対し始めます。

ジュリアコードの下

は、2つの更新ルールのためのいくつかのジュリア・コードです:私は次のようにデータを生成する問題

function OLS_builtin(A,b) 
    x = A\b 
    return(x) 
end 

function OLS_coord_descent(A,b)  
    N,P = size(A) 
    x = zeros(P) 
    for cycle in 1:1000 
     for j = 1:P 
      x[j] = dot(A[:,j], b - A[:,1:P .!= j]*x[1:P .!= j]) 
     end  
    end 
    return(x) 
end 

function OLS_coord_descent_fast(A,b) 
    N,P = size(A) 
    x = zeros(P) 
    for cycle in 1:1000 
     r = b - A*x 
     for j = 1:P 
      x[j] += dot(A[:,j],r) 
     end  
    end 
    return(x) 
end 

例:

n = 100 
p = 50 
σ = 0.1 
β_nz = float([i*(-1)^i for i in 1:10]) 

β = append!(β_nz,zeros(Float64,p-length(β_nz))) 
X = randn(n,p); X .-= mean(X,1); X ./= sqrt(sum(abs2(X),1)) 
y = X*β + σ*randn(n); y .-= mean(y); 

ここで私はp = 50を使用し、OLS_coord_descent(X,y)との間で良好な一致が得られますですが、OLS_coord_descent_fast(X,y)は回帰係数の指数関数的に大きな値を返します。

pが約20未満の場合、OLS_coord_descent_fast(X,y)は他の2つに一致します。

予想

物事がp < < n個の政権のために同意しているので、私はアルゴリズムが正式に正しいが、数値的に不安定だと思います。誰もが、この推測が正しいかどうかについての考えを持っていますか?そうであれば、アルゴリズムの高速バージョンのパフォーマンス向上を(ほとんど)維持しながら、不安定性を修正する方法はありますか?

+0

これはhttp://stats.stackexchange.com/questions/251920/coordinate-descent-in-ordinary-least-squares-notのクロスポストです - 収束する。そのような投稿をしないでください。 StackOverflowはプログラミング問題のためのものであり、これはアルゴリズムの問​​題のようです。私はこれが実際には計算科学専攻に最も適していると思っていますが、3回目の投稿の代わりにマイグレーションされるべきだと思います。 –

+0

ええ、それはアルゴリズム、統計、そしておそらくJuliaの言語特有の問題の奇妙な交差点なので、どこに置くべきかわかりませんでした。より適切な場所に移行したい場合は、してください。 – Rory

答えて

4

クイックアンサー:更新するたびにを更新しています。x[j]が更新されました。 OLS_coord_descentように振る舞う固定機能を以下に示します。

function OLS_coord_descent_fast(A,b) 
    N,P = size(A) 
    x = zeros(P) 
    for cycle in 1:1000 
     r = b - A*x 
     for j = 1:P 
      x[j] += dot(A[:,j],r) 
      r -= A[:,j]*dot(A[:,j],r) # Add this line 
     end  
    end 
    return(x) 
end 
+0

また、 'dot(A [:、j]、r)'はループ内の2つの更新の共通部分式なので、1つの計算に折り返すことができます。内部ループ全体は依然としてO(N)の計算です。 –

+1

はい!とにかくあなたの修正 'r'が更新されるので、サイクルループの外側で' b'に初期化する必要があるので、 'r = b - A * x'算術も切り捨てることができると思います。 – Rory

関連する問題