版本:每晚为带有 GPU 的 Python2 预构建(刚刚)
我期待以下代码打印“10.0”3 次,但 session.run 卡在所有分叉进程中。
import tensorflow as tf
import multiprocessing as mp
import os
class Worker(mp.Process):
def __init__(self, gid):
self.gid = gid
super(Worker, self).__init__()
def run(self):
G = tf.Graph()
with G.as_default():
x = tf.placeholder(tf.float32, shape=[])
y = x * 2
sess = tf.Session()
print sess.run(y, feed_dict={x: 5})
G = tf.Graph()
with G.as_default():
sess = tf.Session()
with sess.as_default():
x = tf.placeholder(tf.float32, shape=[])
y = x * 2
print sess.run(y, feed_dict={x: 5})
procs = [Worker(k) for k in range(2)]
for p in procs: p.start()
for p in procs: p.join()
删除主进程中的图形/会话将解决问题。 所以好像一旦有会话,我们就不能使用 fork 了?
无论有没有 GPU,问题都存在。
注意:此代码不会正常终止。 在 master 退出后,您可能需要手动终止分叉进程。
进程内会话(即没有参数的tf.Session()
)不是设计为fork()
。 如果要在多个进程之间共享一组设备,请在一个进程中创建tf.train.Server
,并在其他进程中创建连接到该服务器的会话(使用tf.Session("grpc://...")
)。
@mrry 这是否意味着有一种方法可以用tf.Session(args)
创建fork
安全的tf.Session
tf.Session(args)
?
@mavenlin
tf.Session
的原型是
tf.Session.__init__(target='', graph=None, config=None)
这里target
指的是要连接的执行引擎。 也就是说,必须在另一个进程中以分布式模式运行工作会话,并且带有参数的tf.Session
仍然不是fork()
。
悲伤的消息。
最有用的评论
@mrry 这是否意味着有一种方法可以用
tf.Session(args)
创建fork
安全的tf.Session
tf.Session(args)
?