0

私は深い学習の初心者で、アルゴリズムの仕組みを理解し、JavaScriptを使って記述しています。今、私はTensorflowのようなconv2dのJavaScript実装に取り​​組んでいます。さまざまなフィルタの数を扱う方法を誤解しています.1つの出力フィルタと複数の出力に成功しましたが、複数のフィルタ入力で操作を生成する方法が混乱しています。 32 - ここ> 64Tensorflow、conv2d、filters

ndarray を使用してコードの例です。

const outCount = 32 // count of inputs filters 
const inCount = 1 // count of output features 
const filterSize = 3 
const stride = 1 
const inShape = [1, 10, 10, outCount] 
const outShape = [ 
    1, 
    Math.ceil((inShape[1] - filterSize + 1)/stride), 
    Math.ceil((inShape[2] - filterSize + 1)/stride), 
    outCount 
]; 
const filters = ndarray([], [filterSize, filterSize, inCount, outCount]) 

const conv2d = (input) => { 
    const result = ndarray(outShape) 
    // for each output feature 

    for (let fo = 0; fo < outCount; fo += 1) { 
    for (let x = 0; x < outShape[1]; x += 1) { 
     for (let y = 0; y < outShape[2]; y += 1) { 
     const fragment = ndarray([], [filterSize, filterSize]); 
     const filter = ndarray([], [filterSize, filterSize]); 

     // agregate fragment of image and filter 
     for (let fx = 0; fx < filterSize; fx += 1) { 
     for (let fy = 0; fy < filterSize; fy += 1) { 
      const dx = (x * stride) + fx; 
      const dy = (y * stride) + fy; 

      fragment.data.push(input.get(0, dx, dy, 0)); 
      filter.data.push(filters.get(fx, fy, 0, fo)); 
     } 
     } 

     // calc dot product of filter and image fragment 
     result.set(0, x, y, fo, dot(filter, fragment)); 
     } 
    } 
    } 

    return result 
} 

テストのために真とそのアルゴリズムのソースは正しいが、1 -> Nで動作するように私はTenforflowを使用しています。しかし、私の質問のような入力値で複数のフィルタのサポートを追加する方法。

誰かがTensorflowとの互換性を高めるためにこのアルゴリズムを変更する方法を説明できますか?tf.nn.conv2d ありがとうございます。

答えて

0

forループを追加する必要があります。すべての入力形状と寸法を指定していないので、実際には正確に書くのは難しいですが、このようになります。

// agregate fragment of image and filter 
    for (let fx = 0; fx < filterSize; fx += 1) { 
    for (let fy = 0; fy < filterSize; fy += 1) { 
     //addition 
     for (let ch = 0; ch < input.get_channels) { 
     const dx = (x * stride) + fx; 
     const dy = (y * stride) + fy; 

     fragment.data.push(input.get(0, dx, dy, ch)); 
     filter.data.push(filters.get(fx, fy, ch, fo)); 
     } 
    } 
    } 
+0

本当にありがとうございます。 – XMANX

関連する問題