DataLoaderから単一のバッチを取得することは可能ですか? 現在、forループを設定し、バッチを手動で返します。
現在、DataLoaderでこれを行う方法がない場合は、機能の追加に取り組んでいきたいと思います。
next(iter(data_loader))
?
かっこいい、それは私が使っていたものよりずっと良いです。
ありがとう!
その答えは、RAMメモリの線形拡張を使用したトレーニングでメモリリークを引き起こしますが、通常のforループ(およびループ内のまったく同じコード)を常に使用しています:/
:+1:@hyperfraiseへ。 これにより、メモリリークが発生します。
次の(わずかに異なる)コードでのメモリリークの同じ問題:
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
X, Y = next(dataloader_iterator)
except:
dataloader_iterator = iter(train_loader)
X, Y = next(dataloader_iterator)
do_backprop(X, Y)
forループの間、メモリ占有率は継続的に増加します。 私はより多くの情報で新しい問題を開くかもしれません(まだ行われていない場合)
これはメモリリークではない可能性がありますが、ループが非常にビジーであり、プロセスを終了するよりも速く生成するという事実です。 DataLoaderイテレータは、非常に短命なオブジェクトを意味するものではありません
私の以前のコメントは正しくありませんでした。 私はリークがコードの他の場所にあることを発見しました(興味のある人のために切り離さずにvarsにぶら下がっていました)。
next(iter(dataloader))を試行すると、「BrokenPipeError:[Errno 32] Brokenpipe」が表示されます
このメソッドを使用して、ループでトレーニングするためのバッチを取得しました。
for i in range(n):
batch = next(iter(data_loader))
データセットの基になる__getitem__
が同じitem
インデックスを取得し続けるように、同じバッチを取得し続けることに気付きました。
これは正常ですか?
@shaibagon
あまり文書化されていませんが、 iter(dataloader)
を実行すると、クラス_DataLoaderIterのオブジェクトが作成され、ループ内で同じオブジェクトがn
回作成され、最初のバッチのみが取得されます。
回避策は、ループの外側に_DataLoaderIterを作成し、それを反復処理することです。 問題は、すべてのバッチが取得されると、_DataLoaderIterがStopIterationエラーを発生させることです。
問題を回避するために、私が現在行っていることは次のとおりです。
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
data, target = next(dataloader_iterator)
except StopIteration:
dataloader_iterator = iter(dataloader)
data, target = next(dataloader_iterator)
do_something()
非常に醜いですが、問題なく動作します。
同様の問題がここにあります。このスレッドで提案された解決策が、そこでの人々にも役立つことを願っています。
@ srossi93素晴らしい解決策。 反復サイクルが終了すると、無視される例外が発生することがあります: ConnectionResetError: [Errno 104] Connection reset by peer
。
マルチプロセッシングが原因のようです。 データローダーのnum_workersを0に設定すると、エラーが消えます。 他の解決策はありますか?
thx @ srossi93
たぶん、このコードは少し良いですか?
def inf_train_gen():
while True:
for images, targets in enumerate(dataloader):
yield images, targets
gen = inf_train_gen
for it in range(num_iters):
images, targets = gen.next()
上記のコードを使用すると、データセットはシャッフルされますか?
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
X, Y = next(dataloader_iterator)
except:
dataloader_iterator = iter(train_loader)
X, Y = next(dataloader_iterator)
do_backprop(X, Y)
なぜこれはすでにジェネレータ/イテレータではないのですか?
@ Yamin05114小さな例を実行して、iter(dataloader)の結果がリセットされるたびにシャッフルされるかどうかを確認しました。 以下のこの小さなスクリプトを実行すると、印刷ステートメントの小さなセットを見て、注文が実際にシャッフルされていることを確認できます。 これは一般的な証明ではありませんが、iter(train_loader)を呼び出すたびにデータがシャッフルされるという説得力のある証拠です。
import torch
from torch.utils.data import Dataset, DataLoader
dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)
for i in range(0, 12):
try:
batch = next(iterloader)
except StopIteration:
iterloader = iter(dataloader)
batch = next(iterloader)
print("iteration" + str(i))
print(batch)
さらに、 @ shaibagonのエラーを再現できませんでした...以下のコードは(上記で定義したものと同じ変数を使用して)別個のバッチを生成するように見えるため、そこで何が起こったのかわかりません。
for i in range(0, 12):
batch = next(iter(dataloader))
print("iteration: " + str(i))
print(batch)
pytorchからdata.Dataset
を継承するdataset
オブジェクトがある場合は、引数としてidxを使用する__getitem__
メソッドをオーバーライドする必要があります。 したがって、直接アクセスできます。
**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
data[i]
また
for i in range(10):
data_batch.__getitem__(i)
最も参考になるコメント
next(iter(data_loader))
?