これは、1つの変数にグラデーションディセントを実装する2番目の試みであり、常に発散します。何か案は?単純な勾配降下はなぜ発散するのですか?
これは、1つの変数の残余平方和を最小にする単純な線形回帰です。
def gradient_descent_wtf(xvalues, yvalues):
tolerance = 0.1
#y=mx+b
#some line to predict y values from x values
m=1.
b=1.
#a predicted y-value has value mx + b
for i in range(0,10):
#calculate y-value predictions for all x-values
predicted_yvalues = list()
for x in xvalues:
predicted_yvalues.append(m*x + b)
# predicted_yvalues holds the predicted y-values
#now calculate the residuals = y-value - predicted y-value for each point
residuals = list()
number_of_points = len(yvalues)
for n in range(0,number_of_points):
residuals.append(yvalues[n] - predicted_yvalues[n])
## calculate the residual sum of squares from the residuals, that is,
## square each residual and add them all up. we will try to minimize
## the residual sum of squares later.
residual_sum_of_squares = 0.
for r in residuals:
residual_sum_of_squares += r**2
print("RSS = %s" % residual_sum_of_squares)
##
##
##
#now make a version of the residuals which is multiplied by the x-values
residuals_times_xvalues = list()
for n in range(0,number_of_points):
residuals_times_xvalues.append(residuals[n] * xvalues[n])
#now create the sums for the residuals and for the residuals times the x-values
residuals_sum = sum(residuals)
residuals_times_xvalues_sum = sum(residuals_times_xvalues)
# now multiply the sums by a positive scalar and add each to m and b.
residuals_sum *= 0.1
residuals_times_xvalues_sum *= 0.1
b += residuals_sum
m += residuals_times_xvalues_sum
#and repeat until convergence.
#convergence occurs when ||sum vector|| < some tolerance.
# ||sum vector|| = sqrt(residuals_sum**2 + residuals_times_xvalues_sum**2)
#check for convergence
magnitude_of_sum_vector = (residuals_sum**2 + residuals_times_xvalues_sum**2)**0.5
if magnitude_of_sum_vector < tolerance:
break
return (b, m)
結果:
gradient_descent_wtf([1,2,3,4,5,6,7,8,9,10],[6,23,8,56,3,24,234,76,59,567])
RSS = 370433.0
RSS = 300170125.7
RSS = 4.86943013045e+11
RSS = 7.90447409339e+14
RSS = 1.28312217794e+18
RSS = 2.08287421094e+21
RSS = 3.38110045417e+24
RSS = 5.48849288217e+27
RSS = 8.90939341376e+30
RSS = 1.44624932026e+34
Out[108]:
(-3.475524066284303e+16, -2.4195981188763203e+17)