J'ai testé en exécutant ces 3 morceaux de code respectivement.
Tout d'abord, init certaines variables et enregistrez.
import tensorflow as tf
sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
saver.save(sess,'v12.ckpt')
Ensuite, restaurez la session, ajoutez une variable supplémentaire et enregistrez.
import tensorflow as tf
sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt') #works fine here
v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))
saver.save(sess,'v123.ckpt')
print v3.eval() #show value without problem
Ensuite, restaurez-les.
import tensorflow as tf
sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
v3 = tf.Variable(3,name="v3")
saver = tf.train.Saver()
saver.restore(sess,'v123.ckpt') #error here
C'est l'erreur :
tensorflow.python.framework.errors.NotFoundError: Tensor name "v3" not found in checkpoint files v123.ckpt [[Node: save/restore_slice_2 = RestoreSlice[dt=DT_INT32, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice_2/tensor_name, save/restore_slice_2/shape_and_slice)]] Caused by op u'save/restore_slice_2'
Quel est le problème ici?
J'utilise la version r0.8 sur Ubuntu16.04
Lorsque vous créez un tf.train.Saver
sans arguments, il utilisera implicitement l'ensemble de variables actuel _au moment de la construction de l'économiseur_ lors de la sauvegarde et de la restauration. Si vous ajoutez une nouvelle variable (par exemple v3
dans votre deuxième bloc de code), vous devez créer un nouveau tf.train.Saver
pour l'enregistrer.
import tensorflow as tf
sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt') #works fine here
v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))
saver_with_v3 = tf.train.Saver()
saver_with_v3.save(sess,'v123.ckpt')
Commentaire le plus utile
Lorsque vous créez un
tf.train.Saver
sans arguments, il utilisera implicitement l'ensemble de variables actuel _au moment de la construction de l'économiseur_ lors de la sauvegarde et de la restauration. Si vous ajoutez une nouvelle variable (par exemplev3
dans votre deuxième bloc de code), vous devez créer un nouveautf.train.Saver
pour l'enregistrer.