2017-04-06 3 views
1

私はPythonコードのこのビットを実行しようとしていたし、エラーを回避することができないように無効な引数エラー:Tensorflow:コンボリューション

ここ
tf.nn.conv2d(tf.reshape(x, [5, 5]), tf.reshape(wt, [3, 3]), strides=[1, 1], padding='SAME') 

は、xはtf.Variableである(5,5 )numpy配列であり、wは(3,3)numpy配列からのtf.Variableです。

私が手にエラーがある:tf.nn.conv2dを使用するためには

--------------------------------------------------------------------------- 
InvalidArgumentError      Traceback (most recent call last) 
C:\Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py in _call_cpp_shape_fn_impl(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn) 
    669   node_def_str, input_shapes, input_tensors, input_tensors_as_shapes, 
--> 670   status) 
    671 except errors.InvalidArgumentError as err: 

C:\Anaconda3\lib\contextlib.py in __exit__(self, type, value, traceback) 
    65    try: 
---> 66     next(self.gen) 
    67    except StopIteration: 

C:\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py in raise_exception_on_not_ok_status() 
    468   compat.as_text(pywrap_tensorflow.TF_Message(status)), 
--> 469   pywrap_tensorflow.TF_GetCode(status)) 
    470 finally: 

InvalidArgumentError: Shape must be rank 4 but is rank 2 for 'Conv2D_19' (op: 'Conv2D') with input shapes: [5,5], [3,3]. 

答えて

0

。入力とフィルタの両方を4Dに変換する必要があります。また、strides1-D of length 4(入力の各サイズのスライディングウィンドウ)にする必要があります。 documentationから取られ、次

Given an input tensor of shape [batch, in_height, in_width, in_channels] and a filter/kernel tensor of shape [filter_height, filter_width, in_channels, out_channels], this op performs the following:

Flattens the filter to a 2-D matrix with shape [filter_height * filter_width * in_channels, output_channels]. Extracts image patches from the input tensor to form a virtual tensor of shape [batch, out_height, out_width, filter_height * filter_width * in_channels]. For each patch, right-multiplies the filter matrix and the image patch vector.

あなたが取ることができます。データのためのtf.reshape(x, [1, 5, 5, 1])、フィルタのtf.reshape(wt, [3, 3, 1, 1])、およびstrides=[1, 1, 1, 1]を。結果は次のようになります。

tf.nn.conv2d(tf.reshape(x, [1, 5, 5, 1]), tf.reshape(wt, [3, 3, 1, 1]), strides=[1, 1, 1, 1], padding='SAME') 
関連する問題