私はpytorchを使ってTextGANを "複製"しようとしています。私はpytorchを初めて使っています。私の現在の懸念はL_G(式7ページ3)を複製することであり、ここでは私の現在のコードです:それは作品Pytorchで上限JSD損失を実装する方法は?
class JSDLoss(nn.Module):
def __init__(self):
super(JSDLoss,self).__init__()
def forward(self, batch_size, f_real, f_synt):
assert f_real.size()[1] == f_synt.size()[1]
f_num_features = f_real.size()[1]
identity = autograd.Variable(torch.eye(f_num_features)*0.1, requires_grad=False)
if use_cuda:
identity = identity.cuda(gpu)
f_real_mean = torch.mean(f_real, 0, keepdim=True)
f_synt_mean = torch.mean(f_synt, 0, keepdim=True)
dev_f_real = f_real - f_real_mean.expand(batch_size,f_num_features)
dev_f_synt = f_synt - f_synt_mean.expand(batch_size,f_num_features)
f_real_xx = torch.mm(torch.t(dev_f_real), dev_f_real)
f_synt_xx = torch.mm(torch.t(dev_f_synt), dev_f_synt)
cov_mat_f_real = (f_real_xx/batch_size) - torch.mm(f_real_mean, torch.t(f_real_mean)) + identity
cov_mat_f_synt = (f_synt_xx/batch_size) - torch.mm(f_synt_mean, torch.t(f_synt_mean)) + identity
cov_mat_f_real_inv = torch.inverse(cov_mat_f_real)
cov_mat_f_synt_inv = torch.inverse(cov_mat_f_synt)
temp1 = torch.trace(torch.add(torch.mm(cov_mat_f_synt_inv, cov_mat_f_real), torch.mm(cov_mat_f_real_inv, cov_mat_f_synt)))
temp1 = temp1.view(1,1)
temp2 = torch.mm(torch.mm((f_synt_mean - f_real_mean), (cov_mat_f_synt_inv + cov_mat_f_real_inv)), torch.t(f_synt_mean - f_real_mean))
loss_g = torch.add(temp1, temp2).mean()
return loss_g
。しかし、私はそれがカスタムロスを作成する方法ではないという疑念を持っています。どんな種類の助けでも大歓迎です!事前に感謝:)
コメント@JMAに感謝します。私はかなり多くのことをすでに知っていましたが、私の質問は実際には正しい方程式を使うのか、紙の方程式を反映するのかにもっと焦点を当てました。あなたがこのことについて私を助けることができるなら、それは全世界を意味するでしょう:) –
あなたはまだこれを確認する助けが必要ですか? – JMA
はい、ディスカッションを電子メールに移行する場合もOKです –