Pytorch: Получите один пакет из DataLoader без повторения

Созданный на 26 июн. 2017  ·  18Комментарии  ·  Источник: pytorch/pytorch

Можно ли получить одну партию из DataLoader? В настоящее время я настраиваю цикл for и возвращаю пакет вручную.
Если в настоящее время нет способа сделать это с помощью DataLoader, я был бы рад поработать над добавлением этой функциональности.

Самый полезный комментарий

next(iter(data_loader)) ?

Все 18 Комментарий

next(iter(data_loader)) ?

Круто, это намного лучше, чем то, что я использовал.
Спасибо!

Этот ответ провоцирует утечку памяти в моем обучении с линейным увеличением оперативной памяти, при постоянном занятии обычным циклом for (и точно таким же кодом в цикле):/

:+1: в @hyperfraise. это создает утечку памяти.

Та же проблема утечки памяти со следующим (немного другим) кодом:

dataloader_iterator = iter(dataloader)
for i in range(iterations):     
    try:
        X, Y = next(dataloader_iterator)
    except:
        dataloader_iterator = iter(train_loader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

Занятость памяти постоянно увеличивается во время цикла for. Я мог бы открыть новый выпуск с дополнительной информацией (если еще не сделано)

Это может быть не утечка памяти, а просто тот факт, что ваш цикл чрезвычайно занят созданием процессов быстрее, чем мы можем их даже завершить. Итераторы DataLoader не должны быть очень недолговечными объектами.

Мой предыдущий комментарий был неверным. Я обнаружил, что утечка была где-то в другом месте кода (для тех, кому любопытно, я цеплялся за vars, не отрываясь).

получить «BrokenPipeError: [Errno 32] Broken pipe» при попытке выполнить следующий (iter (dataloader))

Я использовал этот метод для извлечения пакетов для обучения в цикле:

    for i in range(n):
       batch = next(iter(data_loader))

Я заметил, что продолжаю получать одну и ту же партию, как базовый набор данных __getitem__ продолжает получать один и тот же индекс item .
Это нормально?

@шайбагон
Это не очень хорошо документировано, но когда вы делаете iter(dataloader) , вы создаете объект класса _DataLoaderIter , и в цикле вы создаете один и тот же объект n раз и извлекаете только первую партию.
Обходной путь — создать _DataLoaderIter вне цикла и перебрать его. Проблема в том, что после получения всех пакетов _DataLoaderIter выдаст ошибку StopIteration .

Чтобы избежать проблем, я сейчас делаю следующее:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

Это очень уродливо, но работает просто отлично.

аналогичную проблему можно найти здесь, я надеюсь, что предлагаемые решения в этой теме также могут помочь людям.

@ srossi93 хорошее решение. Иногда я получаю игнорируемое исключение, когда цикл итераций заканчивается: ConnectionResetError: [Errno 104] Connection reset by peer .

По-видимому, вызвано многопроцессорностью. Установка num_workers в загрузчике данных на 0 приводит к исчезновению ошибки. Любые другие решения?

спасибо @srossi93

Может быть, этот код немного лучше?

def inf_train_gen():
    while True:
        for images, targets in enumerate(dataloader):
            yield images, targets
gen = inf_train_gen
for it in range(num_iters):
    images, targets = gen.next()    

Будет ли набор данных перемешиваться, если я использую приведенный выше код?
dataloader_iterator = iter(dataloader) for i in range(iterations): try: X, Y = next(dataloader_iterator) except: dataloader_iterator = iter(train_loader) X, Y = next(dataloader_iterator) do_backprop(X, Y)

почему это уже не генератор/итератор?

@ Yamin05114 Yamin05114 Я запустил небольшой пример, чтобы увидеть, перемешивается ли результат iter (dataloader) каждый раз при сбросе. Если вы запустите этот небольшой скрипт, приведенный ниже, вы сможете просмотреть небольшой набор операторов печати, чтобы убедиться в том, что порядок действительно перемешан. Это не доказательство в целом, но убедительное свидетельство того, что данные перемешиваются каждый раз, когда мы вызываем iter(train_loader).

import torch
from torch.utils.data import Dataset, DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration" + str(i))
    print(batch)

Кроме того, я не смог воспроизвести ошибку @shaibagon ... приведенный ниже код, похоже, создает отдельные партии (с использованием тех же переменных, что и определенные выше), поэтому не уверен, что там произошло.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)

Если у вас есть объект dataset , который наследует data.Dataset от pytorch, он должен переопределить метод __getitem__ , который использует idx в качестве аргумента. Поэтому вы можете получить к нему прямой доступ:

**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
     data[i]

или

for i in range(10):
     data_batch.__getitem__(i)
Была ли эта страница полезной?
0 / 5 - 0 рейтинги