Pytorch: Obtenha um único lote do DataLoader sem iterar

Criado em 26 jun. 2017  ·  18Comentários  ·  Fonte: pytorch/pytorch

É possível obter um único lote de um DataLoader? Atualmente, configuro um loop for e retorno um lote manualmente.
Se não houver uma maneira de fazer isso com o DataLoader atualmente, ficaria feliz em trabalhar para adicionar a funcionalidade.

Comentários muito úteis

next(iter(data_loader)) ?

Todos 18 comentários

next(iter(data_loader)) ?

Legal, isso é muito melhor do que o que eu estava usando.
Obrigado!

Essa resposta provoca um vazamento de memória no meu treinamento com aumento linear da memória RAM, enquanto ocupação constante com um loop for regular (e exatamente o mesmo código no loop) :/

:+1: para @hyperfraise. isso cria um vazamento de memória.

Mesmo problema de vazamento de memória com o seguinte código (um pouco diferente):

dataloader_iterator = iter(dataloader)
for i in range(iterations):     
    try:
        X, Y = next(dataloader_iterator)
    except:
        dataloader_iterator = iter(train_loader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

A ocupação da memória aumenta continuamente durante o loop for. Posso abrir um novo problema com mais informações (se ainda não o tiver feito)

Isso pode não ser um vazamento de memória, mas simplesmente o fato de que seu loop está extremamente ocupado gerando processos mais rápido do que podemos encerrá-los. Os iteradores do DataLoader não devem ser objetos de vida muito curta

Meu comentário anterior estava incorreto. Descobri que o vazamento estava em outro lugar no código (estava me agarrando a vars sem desapegar para quem tiver curiosidade).

obter "BrokenPipeError: [Errno 32] Broken pipe" ao tentar next(iter(dataloader))

Eu usei este método para recuperar lotes para treinamento em um loop:

    for i in range(n):
       batch = next(iter(data_loader))

Percebi que continuo recebendo o mesmo lote, como o __getitem__ subjacente do conjunto de dados continua recebendo o mesmo índice item .
Isso é normal?

@shaibagon
Não está muito bem documentado, mas quando você faz iter(dataloader) você cria um objeto da classe _DataLoaderIter e, no loop, você cria o mesmo objeto n vezes e recupera apenas o primeiro lote.
Uma solução alternativa é criar um _DataLoaderIter fora do loop e iterar sobre ele. O problema é que assim que todos os lotes forem recuperados, _DataLoaderIter gerará um erro StopIteration .

Para evitar problemas, o que estou fazendo atualmente é o seguinte:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

É muito feio, mas funciona bem.

problema semelhante pode ser encontrado aqui Espero que as soluções propostas neste tópico possam ajudar as pessoas de lá também.

@srossi93 boa solução. Às vezes, recebo uma exceção ignorada quando o ciclo de iterações termina: ConnectionResetError: [Errno 104] Connection reset by peer .

Parece ser causado por multiprocessamento. Definir o num_workers no dataloader como 0 faz com que o erro desapareça. Alguma outra solução?

obrigado @srossi93

Talvez este código seja um pouco melhor?

def inf_train_gen():
    while True:
        for images, targets in enumerate(dataloader):
            yield images, targets
gen = inf_train_gen
for it in range(num_iters):
    images, targets = gen.next()    

O conjunto de dados será embaralhado se eu usar o código fornecido acima?
dataloader_iterator = iter(dataloader) for i in range(iterations): try: X, Y = next(dataloader_iterator) except: dataloader_iterator = iter(train_loader) X, Y = next(dataloader_iterator) do_backprop(X, Y)

por que isso não é um gerador/iterador já?

@Yamin05114 Eu executei um pequeno exemplo para ver se o resultado de iter (dataloader) é embaralhado toda vez que é redefinido. Se você executar este pequeno script abaixo, poderá examinar o pequeno conjunto de instruções de impressão para confirmar por si mesmo que o pedido está realmente embaralhado. Esta não é uma prova em geral, mas é uma evidência convincente de que os dados são embaralhados toda vez que chamamos iter(train_loader).

import torch
from torch.utils.data import Dataset, DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration" + str(i))
    print(batch)

Além disso, não consegui reproduzir o erro do @shaibagon ... o código abaixo parece produzir lotes distintos (usando as mesmas variáveis ​​definidas acima), então não tenho certeza do que aconteceu lá.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)

Se você tem um objeto dataset que herda data.Dataset de pytorch, ele deve substituir o método __getitem__ , que usa idx como argumento. Portanto, você pode acessá-lo diretamente:

**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
     data[i]

ou

for i in range(10):
     data_batch.__getitem__(i)
Esta página foi útil?
0 / 5 - 0 avaliações

Questões relacionadas

SeparateReality picture SeparateReality  ·  3Comentários

bartvm picture bartvm  ·  3Comentários

NgPDat picture NgPDat  ·  3Comentários

a1363901216 picture a1363901216  ·  3Comentários

mishraswapnil picture mishraswapnil  ·  3Comentários