是否可以从 DataLoader 获取单个批次? 目前,我设置了一个 for 循环并手动返回一个批次。
如果目前无法使用 DataLoader 执行此操作,我将很乐意添加该功能。
next(iter(data_loader))
?
酷,那比我以前用的好多了。
谢谢!
这个答案在我的训练中通过 RAM 内存的线性增加引起了内存泄漏,同时使用常规 for 循环(以及循环中完全相同的代码)不断占用:/
:+1: @hyperfraise。 这会造成内存泄漏。
以下(略有不同)代码存在相同的内存泄漏问题:
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
X, Y = next(dataloader_iterator)
except:
dataloader_iterator = iter(train_loader)
X, Y = next(dataloader_iterator)
do_backprop(X, Y)
在 for 循环期间内存占用不断增加。 我可能会打开一个包含更多信息的新问题(如果尚未完成)
这可能不是内存泄漏,而仅仅是因为您的循环非常忙于生成进程,速度比我们终止它们的速度还要快。 DataLoader 迭代器并不是生命周期很短的对象
我之前的评论是不正确的。 我发现泄漏在代码的其他地方(对于那些好奇的人,我一直挂在 vars 上而没有分离)。
尝试下一个(iter(dataloader))时得到“BrokenPipeError:[Errno 32] Broken pipe”
我使用这种方法来检索批次以进行循环训练:
for i in range(n):
batch = next(iter(data_loader))
我注意到我不断获得相同的批次,就像数据集的底层__getitem__
不断获得相同的item
索引。
这是正常的吗?
@shaibagon
它没有很好地记录,但是当您执行iter(dataloader)
时,您会创建一个_DataLoaderIter类的对象,并且在循环中,您将创建相同的对象n
次并仅检索第一批。
一种解决方法是在循环外创建一个 _DataLoaderIter 并对其进行迭代。 问题是一旦检索到所有批次,_DataLoaderIter 将引发StopIteration 错误。
为避免出现问题,我目前正在做的事情如下:
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
data, target = next(dataloader_iterator)
except StopIteration:
dataloader_iterator = iter(dataloader)
data, target = next(dataloader_iterator)
do_something()
它非常难看,但效果很好。
可以在这里找到类似的问题我希望这个线程中提出的解决方案也可以帮助那里的人。
@srossi93不错的解决方案。 有时,当迭代周期结束时,我会得到一个被忽略的异常: ConnectionResetError: [Errno 104] Connection reset by peer
。
似乎是由多处理引起的。 将数据加载器上的 num_workers 设置为 0 会使错误消失。 还有其他解决方案吗?
谢谢@srossi93
也许这段代码好一点?
def inf_train_gen():
while True:
for images, targets in enumerate(dataloader):
yield images, targets
gen = inf_train_gen
for it in range(num_iters):
images, targets = gen.next()
如果我使用上面提供的代码,数据集会洗牌吗?
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
X, Y = next(dataloader_iterator)
except:
dataloader_iterator = iter(train_loader)
X, Y = next(dataloader_iterator)
do_backprop(X, Y)
为什么这还不是生成器/迭代器?
@Yamin05114我跑了一个小例子来看看 iter(dataloader) 的结果是否在每次重置时都被洗牌。 如果您在下面运行这个小脚本,您可以查看一小部分打印语句,以自己确认订单确实被打乱了。 这不是一般的证明,但它是令人信服的证据,表明每次调用 iter(train_loader) 时数据都会被打乱。
import torch
from torch.utils.data import Dataset, DataLoader
dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)
for i in range(0, 12):
try:
batch = next(iterloader)
except StopIteration:
iterloader = iter(dataloader)
batch = next(iterloader)
print("iteration" + str(i))
print(batch)
此外,我无法重现@shaibagon的错误......下面的代码似乎产生了不同的批次(使用与上面定义的相同的变量),所以不确定那里发生了什么。
for i in range(0, 12):
batch = next(iter(dataloader))
print("iteration: " + str(i))
print(batch)
如果你有一个从 pytorch 继承data.Dataset
的dataset
对象,它必须覆盖__getitem__
方法,该方法使用 idx 作为参数。 因此,您可以直接访问它:
**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
data[i]
要么
for i in range(10):
data_batch.__getitem__(i)
最有用的评论
next(iter(data_loader))
?