1

私は、Siraj Ravalの動画を使って機械学習を始めました。ビデオ "Intro - The Intelligence of Intelligence"は、kaggle.comのデータセットを使ってGradient Descentを使用して線形回帰を実行することを試みました。これは私のコードです:Pythonで線形回帰を実装する

""" 
An Example of a Linear Regression model. 

Here i am taking an example from https://www.kaggle.com/alopez247/pokemon 
to find a relation between variable "Total" and "HP". 

""" 
import numpy as np 
import pandas as pd 
from matplotlib import pyplot as plt 
import sys 
import os 

data = pd.read_csv("./pokemon_alopez247.csv") 
d = {"Total": data['Total'], 
    "HP": data['HP']} 
smallData = pd.DataFrame(d) 
test = smallData.values 
epsilon = 0.001 


def compute_error_for_line(b, m, points): 
    """Return the Error for Line given the points.""" 
    totalError = 0 
    for i in range(0, len(points)): 
     x = test[i, 0] 
     y = test[i, 1] 
     totalError += (y - (m * x + b)) ** 2 
    return totalError/float(len(points)) 


def step_gradient(b_current, m_current, points, learningRate): 
    """Return the new b and m points.""" 
    b_gradient = 0 
    m_gradient = 0 
    N = float(len(points)) 
    for i in range(0, len(points)): 
     x = points[i, 0] 
     y = points[i, 1] 
     error = y - ((m_current * x) + b_current) 
     b_gradient += -(2/N) * error 
     m_gradient += -(2/N) * x * error 
    new_b = b_current - (learningRate * b_gradient) 
    new_m = m_current - (learningRate * m_gradient) 
    return [new_b, new_m] 


def main(): 
    """Return and plot function here.""" 
    plt.figure(num=None, figsize=(20, 10), dpi=80, 
       facecolor='w', edgecolor='k') 
    plt.axis([0, 780, 0, 260]) 
    plt.ylabel("Total") 
    plt.xlabel("HP") 
    plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1) 

    m = 0.3 
    b = -30 
    x = np.arange(800) 
    y = m * x + b 
    for i in range(30): 
     error = compute_error_for_line(b, m, test) 
     print("error :", error) 
     if(error > epsilon): 
      y = m * x + b 
      plt.plot(x, y) 
      b, m = step_gradient(b, m, test, 0.0001) 
      print("b , m :", b, ",", m) 
      plt.pause(0.01) 

    plt.show() 

    plt.pause(0.001) 

if __name__ == '__main__': 
    try: 
     main() 
    except KeyboardInterrupt: 
     print('Interrupted') 
     try: 
      sys.exit(0) 
     except SystemExit: 
      os._exit(0) 

、出力は次のようになります。

error : 193676.072288 
b , m : -29.91451362 , 6.46934413315 
/usr/local/lib/python3.5/dist-packages/matplotlib/backend_bases.py:2445: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented 
    warnings.warn(str, mplDeprecation) 
error : 16427.2683093 
b , m : -29.9134163218 , 6.04491523016 
error : 15588.2873385 
b , m : -29.9065147511 , 6.07401898958 
error : 15583.8939554 
b , m : -29.9000125838 , 6.07192788394 
error : 15583.4489928 
b , m : -29.8934831191 , 6.07198242461 
error : 15583.0227312 
b , m : -29.8869557061 , 6.07188938575 
error : 15582.5965792 
b , m : -29.8804283262 , 6.07180649992 
error : 15582.1704489 
b , m : -29.8739011182 , 6.07172291798 
error : 15581.74434 
b , m : -29.8673740726 , 6.07163938615 
error : 15581.3182523 
b , m : -29.86084719 , 6.0715558531 
error : 15580.8921858 
b , m : -29.8543204704 , 6.07147232236 
error : 15580.4661407 
b , m : -29.8477939138 , 6.0713887937 
error : 15580.0401168 
b , m : -29.8412675201 , 6.07130526712 
error : 15579.6141143 
b , m : -29.8347412894 , 6.07122174263 
error : 15579.1881329 
b , m : -29.8282152217 , 6.07113822022 
error : 15578.7621729 
b , m : -29.821689317 , 6.0710546999 
error : 15578.3362341 
b , m : -29.8151635752 , 6.07097118166 
error : 15577.9103166 
b , m : -29.8086379963 , 6.07088766551 
error : 15577.4844204 
b , m : -29.8021125804 , 6.07080415145 
error : 15577.0585455 
b , m : -29.7955873275 , 6.07072063947 
error : 15576.6326918 
b , m : -29.7890622375 , 6.07063712957 
error : 15576.2068594 
b , m : -29.7825373104 , 6.07055362176 
error : 15575.7810482 
b , m : -29.7760125462 , 6.07047011604 
error : 15575.3552583 
b , m : -29.769487945 , 6.0703866124 
error : 15574.9294897 
b , m : -29.7629635067 , 6.07030311084 
error : 15574.5037423 
b , m : -29.7564392314 , 6.07021961138 
error : 15574.0780162 
b , m : -29.7499151189 , 6.07013611399 
error : 15573.6523114 
b , m : -29.7433911694 , 6.07005261869 
error : 15573.2266278 
b , m : -29.7368673827 , 6.06996912548 
error : 15572.8009655 
b , m : -29.730343759 , 6.06988563435 
[Finished in 73.209s] 

ので、出力はすべてが計画どおりに起こっていることを示唆しています。しかし、thisを見てください。最初の青は元の値であり、線は遠くになっています!私はcompute_error_for_line関数とstep_gradient関数を書き直そうとしましたが、まだ何もありませんでした。 最後までお読みいただきありがとうございます。

私のサンプルスペースに最適なラインのパラメータをどのように達成できますか?

私のCSVファイルhereにリンクしてください(このファイルの有効期限は22時間です)。

+0

勾配関数は 'compute_error_for_line'と同じようにエラーを計算しません(勾配関数は誤差を二乗しません。他の関数は二乗誤差を行います)。[1]を[0]に変更すると、 。それは目的ですか?実際には – kbrose

+0

@kbroseです。私はstep_gradient関数(bとm resptに関して)の部分的な導関数を取っています。 – Roofi

答えて

1
plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1) 

xとyの値を入れ替えたようです。

+0

これらのタイプのバグを見つける方法についての直感は? – Roofi

+1

私はグラデーションの更新手順を見て、それはうまく見えたので、外れ値があったかもしれないと思いました。私はあなたのコードを見て、どこのデータを省略していたのかを確認していましたが、それを見ました。私は一般的に何かヒントはありますか –