Tensorflow: ์ƒˆ๋กœ ์ถ”๊ฐ€๋œ ๋ณ€์ˆ˜๋ฅผ ์ €์žฅํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.

์— ๋งŒ๋“  2016๋…„ 05์›” 24์ผ  ยท  1๋…ผํ‰  ยท  ์ถœ์ฒ˜: tensorflow/tensorflow

์ด 3๊ฐœ์˜ ์ฝ”๋“œ๋ฅผ ๊ฐ๊ฐ ์‹คํ–‰ํ•˜์—ฌ ํ…Œ์ŠคํŠธํ–ˆ์Šต๋‹ˆ๋‹ค.

๋จผ์ € ์ผ๋ถ€ ๋ณ€์ˆ˜๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๊ณ  ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()
saver.save(sess,'v12.ckpt')

๊ทธ๋Ÿฐ ๋‹ค์Œ ์„ธ์…˜์„ ๋ณต์›ํ•˜๊ณ  ๋ณ€์ˆ˜๋ฅผ ํ•˜๋‚˜ ๋” ์ถ”๊ฐ€ํ•˜๊ณ  ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver.save(sess,'v123.ckpt')

print v3.eval() #show value without problem

๊ทธ๋Ÿฐ ๋‹ค์Œ ๋ณต์›ํ•˜์‹ญ์‹œ์˜ค.

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
v3 = tf.Variable(3,name="v3")

saver = tf.train.Saver()
saver.restore(sess,'v123.ckpt') #error here

์ด๊ฒƒ์€ ์˜ค๋ฅ˜์ž…๋‹ˆ๋‹ค:

tensorflow.python.framework.errors.NotFoundError: Tensor name "v3" not found in checkpoint files v123.ckpt [[Node: save/restore_slice_2 = RestoreSlice[dt=DT_INT32, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice_2/tensor_name, save/restore_slice_2/shape_and_slice)]] Caused by op u'save/restore_slice_2'

์—ฌ๊ธฐ์„œ ๋ฌธ์ œ๊ฐ€ ๋ฌด์—‡์ž…๋‹ˆ๊นŒ?

Ubuntu16.04์—์„œ ๋ฒ„์ „ r0.8์„ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

๊ฐ€์žฅ ์œ ์šฉํ•œ ๋Œ“๊ธ€

์ธ์ˆ˜ ์—†์ด tf.train.Saver ๋ฅผ ์ƒ์„ฑํ•˜๋ฉด ์ €์žฅํ•˜๊ณ  ๋ณต์›ํ•  ๋•Œ _Saver ์ƒ์„ฑ ์‹œ_ ํ˜„์žฌ ๋ณ€์ˆ˜ ์„ธํŠธ๋ฅผ ์•”์‹œ์ ์œผ๋กœ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ƒˆ ๋ณ€์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒฝ์šฐ(์˜ˆ: ๋‘ ๋ฒˆ์งธ ์ฝ”๋“œ ๋ธ”๋ก์— v3 ), ์ด๋ฅผ ์ €์žฅํ•˜๋ ค๋ฉด ์ƒˆ tf.train.Saver ๋ฅผ ์ƒ์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver_with_v3 = tf.train.Saver()
saver_with_v3.save(sess,'v123.ckpt')

>๋ชจ๋“  ๋Œ“๊ธ€

์ธ์ˆ˜ ์—†์ด tf.train.Saver ๋ฅผ ์ƒ์„ฑํ•˜๋ฉด ์ €์žฅํ•˜๊ณ  ๋ณต์›ํ•  ๋•Œ _Saver ์ƒ์„ฑ ์‹œ_ ํ˜„์žฌ ๋ณ€์ˆ˜ ์„ธํŠธ๋ฅผ ์•”์‹œ์ ์œผ๋กœ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ƒˆ ๋ณ€์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒฝ์šฐ(์˜ˆ: ๋‘ ๋ฒˆ์งธ ์ฝ”๋“œ ๋ธ”๋ก์— v3 ), ์ด๋ฅผ ์ €์žฅํ•˜๋ ค๋ฉด ์ƒˆ tf.train.Saver ๋ฅผ ์ƒ์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

import tensorflow as tf

sess = tf.InteractiveSession()
v1 = tf.Variable(1,name="v1")
v2 = tf.Variable(2,name="v2")
saver = tf.train.Saver()
saver.restore(sess,'v12.ckpt')  #works fine here

v3 = tf.Variable(3,name="v3")
sess.run(tf.initialize_variables([v3]))

saver_with_v3 = tf.train.Saver()
saver_with_v3.save(sess,'v123.ckpt')
์ด ํŽ˜์ด์ง€๊ฐ€ ๋„์›€์ด ๋˜์—ˆ๋‚˜์š”?
0 / 5 - 0 ๋“ฑ๊ธ‰