Tensorflow: 新しく追加された変数は保存できません

作成日 2016年05月24日  ·  1コメント  ·  ソース: tensorflow/tensorflow

これらの3つのコードをそれぞれ実行してテストしました。

まず、いくつかの変数を初期化して保存します。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
saver.save(sess,'v12.ckpt')

次に、セッションを復元し、変数をもう1つ追加して、保存します。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver.save(sess,'v123.ckpt')

print v3.eval() #show value without problem

次に、それらを復元します。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
v3 = tf.Variable(3,name="v3")

saver = tf.train.Saver()
saver.restore(sess,'v123.ckpt') #error here

これはエラーです:

tensorflow.python.framework.errors.NotFoundError: Tensor name "v3" not found in checkpoint files v123.ckpt [[Node: save/restore_slice_2 = RestoreSlice[dt=DT_INT32, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice_2/tensor_name, save/restore_slice_2/shape_and_slice)]] Caused by op u'save/restore_slice_2'

ここでの問題は何ですか?

Ubuntu16.04でバージョンr0.8を使用しています

最も参考になるコメント

引数なしでtf.train.Saverを作成すると、保存および復元時に、_セーバーの構築時に_現在の変数のセットが暗黙的に使用されます。 新しい変数を追加する場合(たとえば、2番目のコードブロックにv3 )、それを保存するために新しいtf.train.Saverを作成する必要があります。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver_with_v3 = tf.train.Saver()
saver_with_v3.save(sess,'v123.ckpt')

>すべてのコメント

引数なしでtf.train.Saverを作成すると、保存および復元時に、_セーバーの構築時に_現在の変数のセットが暗黙的に使用されます。 新しい変数を追加する場合(たとえば、2番目のコードブロックにv3 )、それを保存するために新しいtf.train.Saverを作成する必要があります。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver_with_v3 = tf.train.Saver()
saver_with_v3.save(sess,'v123.ckpt')
このページは役に立ちましたか?
0 / 5 - 0 評価