2017-01-12 12 views
2

これは実際にこの質問にはあてはまりません... CNTK python api - continue training a modelこれらは関連していますが、同じではありません。CNTK python api - クラシファイアのトレーニングを続ける

私は1500エポックのモデルを訓練し、67%ほどの平均損失を得ていました。

def Create_Trainer(train_reader, minibatch_size, epoch_size, checkpoint_path=None, distributed_after=INFINITE_SAMPLES): 
#Create Model with Params 
lr_per_minibatch = learning_rate_schedule(
    [0.01] * 10 + [0.003] * 10 + [0.001], UnitType.minibatch, epoch_size) 
momentum_time_constant = momentum_as_time_constant_schedule(
    -minibatch_size/np.log(0.9)) 
l2_reg_weight = 0.0001 
input_var = input_variable((num_channels, image_height, image_width)) 
label_var = input_variable((num_classes)) 
feature_scale = 1.0/256.0 
input_var_norm = element_times(feature_scale, input_var) 
z = create_model(input_var_norm, num_classes) 
#Create Error Functions 
if(checkpoint_path): 
    print('Loaded Checkpoint!') 
    z.load_model(checkpoint_path) 
ce = cross_entropy_with_softmax(z, label_var) 
pe = classification_error(z, label_var)  

#Create Learner  
learner = momentum_sgd(z.parameters, 
         lr=lr_per_minibatch, momentum=momentum_time_constant, 
         l2_regularization_weight=l2_reg_weight) 
if(distributed_after != INFINITE_SAMPLES): 
    learner = distributed.data_parallel_distributed_learner(
     learner = learner, 
     num_quantization_bits = 1, 
     distributed_after = distributed_after 
    ) 
input_map = { 
    input_var: train_reader.streams.features, 
    label_var: train_reader.streams.labels 
} 
return Trainer(z, ce, pe, learner), input_map 

予告コードの行::(checkpoint_path)の場合:およそ半押しを私は、次のように私はコード化されてきたトレーニングを、継続したいです。私はこの機能によって保存され、以前の訓練、から.dnnファイルをロード

...

if current_epoch % checkpoint_frequency == 0: 
      trainer.save_checkpoint(os.path.join(checkpoint_path + "_{}.dnn".format(current_epoch))) 

は、これは実際に.dnnと.dnn.ckpファイルを生成します。明らかに、私はload_modelで.dnnファイルのみを読み込みます。

トレーニングを再開してモデルを読み込むと、ネットワークアーキテクチャがロードされているように見えますが、重みはないと思われますか?これを行う正しい方法論は何ですか?

感謝!

答えて

4

代わりにtrainer.restore_from_checkpointを使用する必要があります。これは、トレーナーと学習者を再作成する必要があります。

すぐに、簡単な方法でシームレスな復元を可能にし、トレーナー/ミニバッチ/分散状態の世話をするトレーニングセッションになります。

重要なこと:pythonスクリプトでは、ノードを作成するネットワーク構造と順序は、チェックポイントを作成してから復元する時点で同じでなければなりません。

関連する問題