私は音楽生成のためのオートエンコーダーを開発しようとしています。その目的のために、私は音楽的関係を捉える損失関数を開発しようとしています。音楽エンコーディングのスパイラルロス機能
私の現在のアイデアは、システムが異なるオクターブで同じ音符を予測すると、音符が間違っている場合よりも損失が小さくなるはずであるという「螺旋」損失関数です。さらに、BやDからCのような正しい音符に近い音符も小さな損失を持つべきです。これは概念的には、コイルまたはスパイラル上の2点間の距離を求めることで、異なるオクターブ内の同じ音符がコイルに接する線に沿って位置するが、ループ距離によって分離されると考えることができる。
私はPyTorchで作業していますが、私の入力表現は36 x 36のTensorで、行は音符(MIDI範囲48:84、ピアノの中央の3オクターブ)を表し、列はタイムステップ(1列= 1/100秒)。行列の値は0または1のいずれかで、特定の時刻にメモがオンであることを示します。ここで
は、損失の私の現在の実装である:
def SpiralLoss():
def spiral_loss(input, output):
loss = Variable(torch.FloatTensor([0]))
d = 5
r = 10
for i in xrange(input.size()[0]):
for j in xrange(input.size()[3]):
# take along the 1 axis because it's a column vector
inval, inind = torch.max(input[i, :, :, j], 1)
outval, outind = torch.max(output[i, :, :, j], 1)
note_loss = (r*30*(inind%12 - outind%12)).float()
octave_loss = (d*(inind/12 - outind/12)).float()
loss += torch.sqrt(torch.pow(note_loss, 2) + torch.pow(octave_loss, 2))
return loss
return spiral_loss
この損失の問題は、MAX関数が微分可能ではないということです。私はこの損失を差別化する方法を考えることができず、誰かがアイデアや提案を持っているのだろうかと疑問に思っていましたか?
これがこのような投稿の適切な場所であるかどうかはわかりません。そうでない場合は、より良い場所に向かうすべての点について本当に感謝しています。
ありがとうございます!
入力と出力の4つの次元は何ですか? – McLawrence
yup! NxCxHxW。私はMIDIピアノロール表現を使って作業しています。 Nはバッチの数であり、Cは畳み込み層での使用のためのチャネル(私の場合は1)である。 Hはミディノートのディメンション、Wは時間(ピアノロールの高さと幅) – bgenchel
ああ。私はあなたが試してみることができ、それが動作するかどうか私に知らせる1つの可能な解決策を投稿 – McLawrence