0
この割り当てのタスクは、レイヤの入力に対する損失の偏微分を計算することです。チェーンルールを実装する必要があります。レイヤーの入力に対する損失の偏微分を計算します。チェーンルール| Python
私は概念的にどのように機能を設定するのが難しいのですか。アドバイスやヒントをいただければ幸いです!
関数変数のサンプルデータは、最下部にあります。
def dense_grad_input(x_input, grad_output, W, b):
"""Calculate the partial derivative of
the loss with respect to the input of the layer
# Arguments
x_input: input of a dense layer - np.array of size `(n_objects, n_in)`
grad_output: partial derivative of the loss functions with
respect to the ouput of the dense layer
np.array of size `(n_objects, n_out)`
W: np.array of size `(n_in, n_out)`
b: np.array of size `(n_out,)`
# Output
the partial derivative of the loss with
respect to the input of the layer
np.array of size `(n_objects, n_in)`
"""
#################
### YOUR CODE ###
#################
return grad_input
#x_input
[[ 0.29682018 0.02620921 0.03910291 0.31660917 0.6809823 0.67731154
0.85846755 0.96218481 0.90590621 0.72424189 0.33797153 0.68878736
0.78965605 0.23509894 0.7241181 0.28966239 0.31927664 0.85477801]
[ 0.9960161 0.4369152 0.89877488 0.78452364 0.22198744 0.04382131
0.4169376 0.69122887 0.25566736 0.44901459 0.50918353 0.8193029
0.29340534 0.46017931 0.64337706 0.63181193 0.81610792 0.45420877]
[ 0.24633573 0.1358581 0.07556498 0.85105726 0.99732196 0.00668041
0.61558841 0.22549151 0.20417495 0.90856472 0.43778948 0.5179694
0.77824586 0.98535274 0.37334145 0.77306608 0.84054839 0.59580074]
[ 0.68575595 0.48426868 0.17377837 0.5779052 0.7824412 0.14172426
0.93237195 0.71980057 0.04890449 0.35121393 0.67403124 0.71114348
0.32314314 0.84770232 0.10081962 0.27920494 0.52890886 0.64462433]
[ 0.35874758 0.96694283 0.374106 0.40640907 0.59441666 0.04155628
0.57434682 0.43011294 0.55868019 0.59398029 0.22563919 0.39157997
0.31804255 0.63898075 0.32462043 0.95516196 0.40595824 0.24739606]]
#grad_output
[[ 0.30650667 0.66195042 0.32518952 0.68266843 0.16748198]
[ 0.87112224 0.66131922 0.03093839 0.61508666 0.21811778]
[ 0.95191614 0.70929627 0.42584023 0.59418774 0.75341628]
[ 0.32523626 0.90275084 0.3625107 0.52354435 0.23991962]
[ 0.89248732 0.55744782 0.02718998 0.82430586 0.73937504]]
#W
[[ 0.8584596 0.28496554 0.6743653 0.81776177 0.28957213]
[ 0.96371309 0.19263171 0.78160551 0.07797744 0.21341943]
[ 0.5191679 0.02631223 0.37672431 0.7439749 0.53042904]
[ 0.1472284 0.46261313 0.18701797 0.17023813 0.63925535]
[ 0.6169004 0.43381192 0.93162705 0.62511267 0.45877614]
[ 0.30612274 0.39457724 0.26087929 0.34826782 0.71235394]
[ 0.66890267 0.70557853 0.48098531 0.76937604 0.10892615]
[ 0.17080091 0.57693496 0.19482135 0.07942299 0.7505965 ]
[ 0.61697062 0.1725569 0.21757211 0.64178749 0.41287085]
[ 0.96790726 0.22636129 0.38378524 0.02240361 0.08083711]
[ 0.67933 0.34274892 0.55247312 0.06602492 0.75212193]
[ 0.00522951 0.49808998 0.83214543 0.46631055 0.48400103]
[ 0.56771735 0.70766078 0.27010417 0.73044053 0.80382 ]
[ 0.12586939 0.18685427 0.66328521 0.84542463 0.7792 ]
[ 0.21744701 0.90146876 0.67373118 0.88915982 0.5605676 ]
[ 0.71208837 0.89978603 0.34720491 0.79784756 0.73914921]
[ 0.48384807 0.10921725 0.81603026 0.82053322 0.45465871]
[ 0.56148353 0.31003923 0.39570321 0.7816182 0.23360955]]
#b
[ 0.10006862 0.36418521 0.56036054 0.32046732 0.57004243]
どちらが損失ですか? – hedgehogues