Pytorch: Obtenga un solo lote de DataLoader sin iterar

Creado en 26 jun. 2017  ·  18Comentarios  ·  Fuente: pytorch/pytorch

¿Es posible obtener un solo lote de un DataLoader? Actualmente, configuro un bucle for y devuelvo un lote manualmente.
Si no hay una manera de hacer esto con el DataLoader actualmente, estaría feliz de trabajar para agregar la funcionalidad.

Comentario más útil

next(iter(data_loader)) ?

Todos 18 comentarios

next(iter(data_loader)) ?

Genial, eso es mucho mejor que lo que había estado usando.
¡Gracias!

Esa respuesta provoca una fuga de memoria en mi entrenamiento con el aumento lineal de la memoria RAM, mientras que la ocupación constante con un bucle for regular (y exactamente el mismo código en el bucle) :/

:+1: a @hyperfraise. esto crea una fuga de memoria.

El mismo problema de pérdida de memoria con el siguiente código (ligeramente diferente):

dataloader_iterator = iter(dataloader)
for i in range(iterations):     
    try:
        X, Y = next(dataloader_iterator)
    except:
        dataloader_iterator = iter(train_loader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

La ocupación de la memoria aumenta continuamente durante el ciclo for. Podría abrir un nuevo problema con más información (si aún no lo he hecho)

Esto podría no ser una pérdida de memoria, sino simplemente el hecho de que su ciclo está extremadamente ocupado generando procesos más rápido de lo que podemos terminarlos. Los iteradores de DataLoader no están destinados a ser objetos de muy corta duración.

Mi comentario anterior era incorrecto. Descubrí que la fuga estaba en otra parte del código (me estaba aferrando a vars sin desconectarme para aquellos que tienen curiosidad).

obtenga "BrokenPipeError: [Errno 32] Broken pipe" al intentar siguiente (iter (cargador de datos))

Usé este método para recuperar lotes para entrenar en un bucle:

    for i in range(n):
       batch = next(iter(data_loader))

Me di cuenta de que sigo recibiendo el mismo lote, como si el __getitem__ subyacente del conjunto de datos siguiera obteniendo el mismo índice item .
¿Esto es normal?

@shaibagon
No está muy bien documentado, pero cuando haces iter(dataloader) , creas un objeto de clase _DataLoaderIter y, en el ciclo, crearás el mismo objeto n veces y recuperarás solo el primer lote.
Una solución consiste en crear un _DataLoaderIter fuera del bucle e iterarlo. El problema es que una vez que se recuperan todos los lotes, _DataLoaderIter generará un error StopIteration .

Para evitar problemas, lo que estoy haciendo actualmente es lo siguiente:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

Es muy feo pero funciona bien.

Se puede encontrar un problema similar aquí . Espero que las soluciones propuestas en este hilo también puedan ayudar a las personas allí.

@ srossi93 buena solución. A veces recibo una excepción ignorada cuando finaliza el ciclo de iteraciones: ConnectionResetError: [Errno 104] Connection reset by peer .

Parece ser causado por el multiprocesamiento. Establecer num_workers en el cargador de datos en 0 hace que el error desaparezca. ¿Alguna otra solución?

gracias @srossi93

¿Quizás este código es un poco mejor?

def inf_train_gen():
    while True:
        for images, targets in enumerate(dataloader):
            yield images, targets
gen = inf_train_gen
for it in range(num_iters):
    images, targets = gen.next()    

¿Se mezclará el conjunto de datos si utilizo el código proporcionado anteriormente?
dataloader_iterator = iter(dataloader) for i in range(iterations): try: X, Y = next(dataloader_iterator) except: dataloader_iterator = iter(train_loader) X, Y = next(dataloader_iterator) do_backprop(X, Y)

¿Por qué esto no es un generador/iterador ya?

@ Yamin05114 Ejecuté un pequeño ejemplo para ver si el resultado de iter (cargador de datos) se baraja cada vez que se reinicia. Si ejecuta este pequeño script a continuación, puede mirar el pequeño conjunto de declaraciones impresas para confirmar por sí mismo que el pedido se barajó. Esto no es una prueba en general, pero es una evidencia convincente de que los datos se barajan cada vez que llamamos a iter(train_loader).

import torch
from torch.utils.data import Dataset, DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration" + str(i))
    print(batch)

Además, no pude reproducir el error de @shaibagon ... el siguiente código parece producir lotes distintos (usando las mismas variables definidas anteriormente), por lo que no estoy seguro de qué sucedió allí.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)

Si tiene un objeto dataset que hereda data.Dataset de pytorch, debe anular el método __getitem__ , que usa idx como argumento. Por lo tanto, puede acceder directamente:

**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
     data[i]

o

for i in range(10):
     data_batch.__getitem__(i)
¿Fue útil esta página
0 / 5 - 0 calificaciones