Tensorflow: 无法保存新添加的变量

创建于 2016-05-24  ·  1评论  ·  资料来源: tensorflow/tensorflow

我分别通过运行这 3 段代码进行了测试。

首先,初始化一些变量并保存。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
saver.save(sess,'v12.ckpt')

然后,恢复会话,再添加一个变量,然后保存。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver.save(sess,'v123.ckpt')

print v3.eval() #show value without problem

然后,恢复它们。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
v3 = tf.Variable(3,name="v3")

saver = tf.train.Saver()
saver.restore(sess,'v123.ckpt') #error here

这是错误:

tensorflow.python.framework.errors.NotFoundError: Tensor name "v3" not found in checkpoint files v123.ckpt [[Node: save/restore_slice_2 = RestoreSlice[dt=DT_INT32, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice_2/tensor_name, save/restore_slice_2/shape_and_slice)]] Caused by op u'save/restore_slice_2'

这里有什么问题?

我在 Ubuntu16.04 上使用版本 r0.8

最有用的评论

当你创建一个没有参数的tf.train.Saver时,它会在保存和恢复时隐式使用当前的变量集_在 Saver 构建时_。 如果您添加一个新变量(例如在您的第二个代码块中的v3 ),您必须创建一个新的tf.train.Saver来保存它。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver_with_v3 = tf.train.Saver()
saver_with_v3.save(sess,'v123.ckpt')

>所有评论

当你创建一个没有参数的tf.train.Saver时,它会在保存和恢复时隐式使用当前的变量集_在 Saver 构建时_。 如果您添加一个新变量(例如在您的第二个代码块中的v3 ),您必须创建一个新的tf.train.Saver来保存它。

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver_with_v3 = tf.train.Saver()
saver_with_v3.save(sess,'v123.ckpt')
此页面是否有帮助?
0 / 5 - 0 等级