2016-05-17 14 views
4

私は機械学習とテンソルフローの初心者です。テンソルフローを試みる最初のステップで、私は単純な多変量線形回帰を試みました。しかし、それは地元の最小値にとどまっているようだ。ここに私のコードです。単純な線形回帰のTensorflow

import numpy as np 
import tensorflow as tf 
import matplotlib.pyplot as plt 

def weight_variable(shape): 
    initial = tf.truncated_normal(shape, stddev=1) 
    return tf.Variable(initial) 

# dataset 
xx = np.random.randint(0,1000,[1000,3])/1000. 
yy = xx[:,0] * 2 + xx[:,1] * 1.4 + xx[:,2] * 3 

# model 
x = tf.placeholder(tf.float32, shape=[None, 3]) 
y_ = tf.placeholder(tf.float32, shape=[None]) 
W1 = weight_variable([3, 1]) 
y = tf.matmul(x, W1) 

# training and cost function 
cost_function = tf.reduce_mean(tf.square(y - y_)) 
train_function = tf.train.AdamOptimizer(1e-2).minimize(cost_function) 

# create a session 
sess = tf.Session() 

# train 
sess.run(tf.initialize_all_variables()) 
for i in range(10000): 
    sess.run(train_function, feed_dict={x:xx, y_:yy}) 
    if i % 1000 == 0: 
     print(sess.run(cost_function, feed_dict={x:xx, y_:yy})) 

出力は次のとおり

14.8449 
2.20154 
2.18375 
2.18366 
2.18366 
2.18366 
2.18366 
2.18366 
2.18366 

出力値(YY)のように2.18に添加ないノイズが存在しないことを知って、かなり大きい平均二乗誤差を有する、0から6の範囲でありますデータセット。 また、学習率0.1と1e-2のGradientDescentOptimizerを試しましたが、結果はあまり改善されません。

実装に問題はありますか?

+0

私は表示されませんerror_functionデあなたのコードに罰金が科せられます。これはcost_functionsと同じですか?学習率を下げようとしましたか? – shekkizh

+0

申し訳ありませんが、私は 'cost_function'を意味しました。私はそれを修正した。ご意見ありがとうございます。 – Firman

答えて

5

yは、y_と同じ形状ではないためです。 yは形状(1000,1)であり、y_は形状(1000)である。したがって、それらを減算すると、誤って2-Dマトリックスが作成されます。

それがあなたの費用関数を変更する修正するには:別の答えで述べたように

cost_function = tf.reduce_mean(tf.square(tf.squeeze(y) - y_)) 
-2

、uが

predictions = tf.add(b, tf.matmul(x, w)) 
error = tf.reduce_mean(tf.square(y - predictions)) 

そして、あなたはあなたがTensorflow初心者です、と言っているように使用する必要があり、あなたはここでは例を見ることができます: -

https://medium.com/@saxenarohan97/intro-to-tensorflow-solving-a-simple-regression-problem-e87b42fd4845

+0

このリンクは質問に答えるかもしれませんが、答えの本質的な部分をここに含めて参考にしてください。リンクされたページが変更された場合、リンクのみの回答は無効になります。 - [レビューの投稿](レビュー/低品質の投稿/ 18667877) – Liam

+0

ok、私は自分の答えを編集しました –

関連する問題