Pytorch: Get a single batch from DataLoader without iterating

Created on 26 Jun 2017  ·  18 Comments  ·  Source: pytorch/pytorch

Is it possible to get a single batch from a DataLoader? Currently, I set up a for loop and return a batch manually.
If there isn't a way to do this with the DataLoader currently, I would be happy to work on adding the functionality.

Most helpful comment

next(iter(data_loader)) ?

All 18 comments

next(iter(data_loader)) ?
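
For example, as a minimal sketch (assuming data_loader is an existing DataLoader yielding (images, labels) pairs):

    images, labels = next(iter(data_loader))  # fetch a single batch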

Cool, that's a lot better than what I had been using.
Thanks!

That answer causes a memory leak in my training: RAM usage grows linearly, whereas it stays constant with a regular for loop (and exactly the same code inside the loop) :/

:+1: to @hyperfraise. This creates a memory leak.

Same problem of memory leak with the following (slightly different) code:

dataloader_iterator = iter(dataloader)
for i in range(iterations):
    try:
        X, Y = next(dataloader_iterator)
    except StopIteration:
        # epoch exhausted: restart the iterator
        dataloader_iterator = iter(dataloader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

Memory usage increases continuously during the for loop. I might open a new issue with more information (if one hasn't been opened already).

This might not be a memory leak but simply the fact that your loop is so busy spawning worker processes that we can't even terminate them fast enough. DataLoader iterators are not meant to be very short-lived objects.
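
To illustrate the difference, a minimal sketch (dataloader and n are assumed names; dataloader is a DataLoader with num_workers > 0):

    # pattern 1: a fresh iterator (and fresh worker processes) on every pass
    for i in range(n):
        batch = next(iter(dataloader))

    # pattern 2: one long-lived iterator, workers spawned once
    dataloader_iterator = iter(dataloader)
    for i in range(n):
        batch = next(dataloader_iterator)  # raises StopIteration once the epoch is exhausted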

My previous comment was incorrect. I discovered that the leak was elsewhere in the code (for those who are curious, I was hanging on to variables without detaching them).

get "BrokenPipeError: [Errno 32] Broken pipe" when trying next(iter(dataloader))

I used this method to retrieve batches for training in a loop:

    for i in range(n):
        batch = next(iter(data_loader))

I noticed I keep getting the same batch, as if the underlying __getitem__ of the dataset keeps receiving the same item index.
Is this normal?

@shaibagon
It's not documented very well, but when you do iter(dataloader) you create an object of class _DataLoaderIter. In your loop, you create that same object n times and only ever retrieve the first batch.
A workaround is to create a _DataLoaderIter outside the loop and iterate over it. The problem is that once all batches have been retrieved, _DataLoaderIter raises StopIteration.

To avoid problems, what I'm currently doing is the following:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

It's very ugly but it works just fine.

A similar issue can be found here; I hope the solutions proposed in this thread can help people there as well.

@srossi93 nice solution. Sometimes I get an ignored exception when the iteration cycle ends: ConnectionResetError: [Errno 104] Connection reset by peer.

It appears to be caused by multiprocessing. Setting num_workers on the DataLoader to 0 makes the error disappear. Any other solutions?
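
For reference, a minimal sketch of that workaround (the dataset and batch size here are assumptions):

    from torch.utils.data import DataLoader

    # num_workers=0 loads batches in the main process, so no worker
    # processes need to be spawned or torn down
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)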

thx @srossi93

Maybe this code is a little better?

def inf_train_gen():
    # infinite generator: restarts (and, with shuffle=True, reshuffles) the DataLoader each epoch
    while True:
        for images, targets in dataloader:
            yield images, targets

gen = inf_train_gen()
for it in range(num_iters):
    images, targets = next(gen)

Will the dataset shuffle if I use the code provided above?
dataloader_iterator = iter(dataloader)
for i in range(iterations):
    try:
        X, Y = next(dataloader_iterator)
    except StopIteration:
        dataloader_iterator = iter(dataloader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

why is this not a generator/iterator already?

@Yamin05114 I ran a small example to see whether the result of iter(dataloader) is shuffled every time it is reset. If you run the little script below, the print statements will let you confirm for yourself that the order is indeed shuffled. This is not a proof in general, but it is convincing evidence that the data is shuffled every time we call iter(dataloader).

import torch
from torch.utils.data import DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration" + str(i))
    print(batch)

In addition, I could not reproduce the issue @shaibagon described: the code below seems to produce distinct batches (using the same variables as defined above), so I'm not sure what happened there.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)
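
One possible explanation (an assumption on my part, not something @shaibagon confirmed): with shuffle=False every fresh iterator starts from the same first batch, so next(iter(dataloader)) in a loop returns identical batches:

dataloader_noshuffle = DataLoader(dataset, batch_size=2, shuffle=False)
for i in range(3):
    print(next(iter(dataloader_noshuffle)))  # tensor([0, 1]) every time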

If you have a dataset object that inherits from PyTorch's data.Dataset, it must override the __getitem__ method, which takes an index (idx) as an argument. Therefore you can index it directly:

# some dataset instance called data
data = Dataset(**kwargs)
for i in range(10):
    data[i]

or

for i in range(10):
    data.__getitem__(i)
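
For completeness, a minimal sketch of a custom Dataset that overrides __getitem__ and can be indexed directly (the class name and data here are just for illustration):

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, tensor):
        self.tensor = tensor

    def __len__(self):
        return len(self.tensor)

    def __getitem__(self, idx):
        # called both by direct indexing (data[idx]) and by DataLoader
        return self.tensor[idx]

data = MyDataset(torch.arange(20))
print(data[3])  # tensor(3)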