2017-01-15 4 views
2

SWIFTでA * X = B型の線形方程式の系を解こうとしています。タイプA * X = Bの方程式を解くdgtsv_またはsgtsv_を使用

これは、O(N^2)メモリーを消費するLU分解に基づくアルゴリズムを使用してこれを実行できました。

私の配列は一般的に大きく(10000サンプル以上)、LAPACKはO(N)メモリ空間しか消費しないトリジアゴナル行列に特有のいくつかの機能を持っています。&がより効率的です。

http://www.netlib.org/lapack/explore-html-3.4.2/d4/d62/group__double_g_tsolve.html#

基本的に、私は上記のdgtsv_またはsgtsv_機能を使って方程式を解くために探しています。しかし、私が見つけることができる例はありません。

私はSWIFTの新機能として、関数が要求する8つの入力パラメータを渡すことに苦労しています。どこかに例がありますか?

作業コードの下に(LU分解を使用して)貼り付けます。

import Accelerate 

func solve(A:[Double], _ B:[Double]) -> [Double] { 

var inMatrix:[Double] = A 

var solution:[Double] = B 

// Get the dimensions of the matrix. An NxN matrix has N^2 
// elements, so sqrt(N^2) will return N, the dimension 
var N:__CLPK_integer = __CLPK_integer(sqrt(Double(A.count))) 

// Number of columns on the RHS 
var NRHS:__CLPK_integer = 1 

// Leading dimension of A and B 
var LDA:__CLPK_integer = N 

var LDB:__CLPK_integer = N 

// Initialize some arrays for the dgetrf_(), and dgetri_() functions 
var pivots:[__CLPK_integer] = [__CLPK_integer](repeating: 0, count: Int(N)) 

var error: __CLPK_integer = 0 

// Perform LU factorization 
dgetrf_(&N, &N, &inMatrix, &N, &pivots, &error) 

// Calculate solution from LU factorization 
_ = "T".withCString { 
    dgetrs_(UnsafeMutablePointer(mutating: $0), &N, &NRHS, &inMatrix, &LDA, &pivots, &solution, &LDB, &error) 
} 
return solution 
    } 


    //Call the function 
    var A: [Double] = [ 
     1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
     1.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
     0.0, 1.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
     0.0, 0.0, 1.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
     0.0, 0.0, 0.0, 1.0, 4.0, 1.0, 0.0, 0.0, 0.0, 0.0, 
     0.0, 0.0, 0.0, 0.0, 1.0, 4.0, 1.0, 0.0, 0.0, 0.0, 
     0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 4.0, 1.0, 0.0, 0.0, 
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 4.0, 1.0, 0.0, 
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 4.0, 1.0, 
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0] 

    var b: [Double] = [0, -15, -15, -3, -3, 45, -12, -6, 0, 0] 

    var cj = solve(A: A, b) 

    print(cj) // --> [0.0, -2.9185349611542728, -3.3258601553829079, 1.2219755826859044, -4.5620421753607099, 14.026193118756936, -6.5427302996670358, 0.14472807991120964, -0.036182019977802411, 0.0] 
    //Call the function 


    //TRY LAPACK (need examples to get above solution) 
    let xx = dgtsv_(<#T##__n: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>, <#T##__nrhs: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>, <#T##__dl: UnsafeMutablePointer<__CLPK_doublereal>!##UnsafeMutablePointer<__CLPK_doublereal>!#>, <#T##__d__: UnsafeMutablePointer<__CLPK_doublereal>!##UnsafeMutablePointer<__CLPK_doublereal>!#>, <#T##__du: UnsafeMutablePointer<__CLPK_doublereal>!##UnsafeMutablePointer<__CLPK_doublereal>!#>, <#T##__b: UnsafeMutablePointer<__CLPK_doublereal>!##UnsafeMutablePointer<__CLPK_doublereal>!#>, <#T##__ldb: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>, <#T##__info: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>) 

    let xx2 = sgtsv_(<#T##__n: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>, <#T##__nrhs: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>, <#T##__dl: UnsafeMutablePointer<__CLPK_real>!##UnsafeMutablePointer<__CLPK_real>!#>, <#T##__d__: UnsafeMutablePointer<__CLPK_real>!##UnsafeMutablePointer<__CLPK_real>!#>, <#T##__du: UnsafeMutablePointer<__CLPK_real>!##UnsafeMutablePointer<__CLPK_real>!#>, <#T##__b: UnsafeMutablePointer<__CLPK_real>!##UnsafeMutablePointer<__CLPK_real>!#>, <#T##__ldb: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>, <#T##__info: UnsafeMutablePointer<__CLPK_integer>!##UnsafeMutablePointer<__CLPK_integer>!#>) 
    //TRY LAPACK (need examples to get above solution) 

答えて

2

dgtsv_()は期待より低い/別個の引数としてトリ対角線 行列の対角線上部/メイン。 変数アレイのアドレスに&を渡すことができます。

すべての整数パラメータは、__CLPK_integer別名Int32 のアドレスです。

右側のベクトルbは、xから の式A x = bの方程式で上書きされます。 Aを表す3つのベクトルも同様に に上書きされるため、元のデータのコピーを作成することができます。

例:

import Swift 
import Accelerate 

var mainDiagA = [ 1.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0 ] 
var upperDiagA = [ 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ] 
var lowerDiagA = [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0 ] 

var b = [0.0, -15.0, -15.0, -3.0, -3.0, 45.0, -12.0, -6.0, 0.0, 0.0 ] 

var n = Int32(mainDiagA.count) // Order of matrix A 
var nrhs = Int32(1) // Number of right-hand sides 
var info = Int32(0) // Result code 

dgtsv_(&n, &nrhs, &lowerDiagA, &mainDiagA, &upperDiagA, &b, &n, &info) 
if info == 0 { // success 
    print(b) 
    // [0.0, -2.9185349611542732, -3.3258601553829075, 1.2219755826859044, -4.5620421753607099, 14.026193118756938, -6.5427302996670367, 0.14472807991120964, -0.036182019977802411, 0.0] 

} 
+0

おかげマーティン - これが今の私のために完璧に適しています! – Pat

関連する問題