バージョン:GPUを搭載したPython2用に毎晩ビルド済み(ちょうど今)
次のコードが「10.0」を3回出力することを期待していますが、session.runがすべてのフォークされたプロセスでスタックしました。
import tensorflow as tf
import multiprocessing as mp
import os
class Worker(mp.Process):
def __init__(self, gid):
self.gid = gid
super(Worker, self).__init__()
def run(self):
G = tf.Graph()
with G.as_default():
x = tf.placeholder(tf.float32, shape=[])
y = x * 2
sess = tf.Session()
print sess.run(y, feed_dict={x: 5})
G = tf.Graph()
with G.as_default():
sess = tf.Session()
with sess.as_default():
x = tf.placeholder(tf.float32, shape=[])
y = x * 2
print sess.run(y, feed_dict={x: 5})
procs = [Worker(k) for k in range(2)]
for p in procs: p.start()
for p in procs: p.join()
マスタープロセスでグラフ/セッションを削除すると、問題が解決します。 それで、セッションがあると、フォークを使用できないように見えますか?
この問題は、GPUがある場合とない場合に存在します。
注:このコードは正常に終了しません。 マスターが終了した後、フォークされたプロセスを手動で強制終了する必要があります。
インプロセスセッション(つまり、引数のないtf.Session()
)は、 fork()
安全になるようには設計されていません。 複数のプロセス間でデバイスのセットを共有する場合は、1つのプロセスでtf.train.Server
を作成し、他のプロセスでそのサーバーに接続するセッションを作成します( tf.Session("grpc://...")
)。
@mrryないことがある意味では、作成する方法ですfork
安全tf.Session
とtf.Session(args)
?
@mavenlin
tf.Session
のプロトタイプは
tf.Session.__init__(target='', graph=None, config=None)
ここで、 target
は、接続する実行エンジンを指します。 つまり、分散モードで別のプロセスで作業セッションを実行する必要があり、引数付きのtf.Session
はまだfork()
safeではありません。
悲しいニュース。
最も参考になるコメント
@mrryないことがある意味では、作成する方法です
fork
安全tf.Session
とtf.Session(args)
?