Pytorch: Obtenez un seul lot de DataLoader sans itérer

Créé le 26 juin 2017  ·  18Commentaires  ·  Source: pytorch/pytorch

Est-il possible d'obtenir un seul lot à partir d'un DataLoader ? Actuellement, je configure une boucle for et renvoie un lot manuellement.
S'il n'y a pas moyen de le faire avec le DataLoader actuellement, je serais heureux de travailler sur l'ajout de la fonctionnalité.

Commentaire le plus utile

next(iter(data_loader)) ?

Tous les 18 commentaires

next(iter(data_loader)) ?

Cool, c'est beaucoup mieux que ce que j'utilisais.
Merci!

Cette réponse provoque une fuite de mémoire dans ma formation avec une augmentation linéaire de la mémoire RAM, alors qu'une occupation constante avec une boucle for régulière (et exactement le même code dans la boucle) :/

:+1: à @hyperfraise. cela crée une fuite de mémoire.

Même problème de fuite mémoire avec le code suivant (légèrement différent) :

dataloader_iterator = iter(dataloader)
for i in range(iterations):     
    try:
        X, Y = next(dataloader_iterator)
    except:
        dataloader_iterator = iter(train_loader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

L'occupation de la mémoire augmente continuellement pendant la boucle for. Je pourrais ouvrir un nouveau sujet avec plus d'informations (si ce n'est pas encore fait)

Il ne s'agit peut-être pas d'une fuite de mémoire, mais simplement du fait que votre boucle est extrêmement occupée à générer des processus plus rapidement que nous ne pouvons même les terminer. Les itérateurs DataLoader ne sont pas censés être des objets de très courte durée

Mon commentaire précédent était incorrect. J'ai découvert que la fuite était ailleurs dans le code (je m'accrochais aux vars sans les détacher pour ceux qui sont curieux).

obtenir "BrokenPipeError: [Errno 32] Tuyau cassé" lors de la prochaine tentative (iter (dataloader))

J'ai utilisé cette méthode pour récupérer des lots pour la formation en boucle:

    for i in range(n):
       batch = next(iter(data_loader))

J'ai remarqué que je reçois toujours le même lot, comme le __getitem__ sous-jacent de l'ensemble de données qui reçoit toujours le même index item .
Est-ce normal?

@shaibagon
Ce n'est pas très bien documenté mais lorsque vous faites iter(dataloader) vous créez un objet de classe _DataLoaderIter et, dans la boucle, vous allez créer le même objet n fois et récupérer le premier lot uniquement.
Une solution de contournement consiste à créer un _DataLoaderIter en dehors de la boucle et à itérer dessus. Le problème est qu'une fois tous les lots récupérés, _DataLoaderIter génère une erreur StopIteration .

Pour éviter les problèmes, ce que je fais actuellement est le suivant:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

C'est très moche mais ça marche très bien.

un problème similaire peut être trouvé ici J'espère que les solutions proposées dans ce fil pourront également aider les gens là-bas.

@ srossi93 belle solution. Parfois, j'obtiens une exception ignorée à la fin du cycle d'itérations : ConnectionResetError: [Errno 104] Connection reset by peer .

Semble être causé par le multitraitement. La définition de num_workers sur le chargeur de données sur 0 fait disparaître l'erreur. D'autres solutions ?

merci @srossi93

Peut-être que ce code est un peu mieux ?

def inf_train_gen():
    while True:
        for images, targets in enumerate(dataloader):
            yield images, targets
gen = inf_train_gen
for it in range(num_iters):
    images, targets = gen.next()    

L'ensemble de données sera-t-il mélangé si j'utilise le code fourni ci-dessus ?
dataloader_iterator = iter(dataloader) for i in range(iterations): try: X, Y = next(dataloader_iterator) except: dataloader_iterator = iter(train_loader) X, Y = next(dataloader_iterator) do_backprop(X, Y)

pourquoi n'est-ce pas déjà un générateur/itérateur ?

@ Yamin05114 J'ai exécuté un petit exemple pour voir si le résultat de iter (dataloader) est mélangé à chaque fois qu'il est réinitialisé. Si vous exécutez ce petit script ci-dessous, vous pouvez consulter le petit ensemble d'instructions d'impression pour confirmer par vous-même que la commande est bien mélangée. Ce n'est pas une preuve en général, mais c'est une preuve convaincante que les données sont mélangées à chaque fois que nous appelons iter(train_loader).

import torch
from torch.utils.data import Dataset, DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration" + str(i))
    print(batch)

De plus, je n'ai pas pu reproduire l'erreur de @shaibagon ... le code ci-dessous semble produire des lots distincts (utilisant les mêmes variables que celles définies ci-dessus), donc je ne sais pas ce qui s'est passé là-bas.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)

Si vous avez un objet dataset qui hérite data.Dataset de pytorch, il doit remplacer la méthode __getitem__ , qui utilise idx comme argument. Vous pouvez donc y accéder directement :

**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
     data[i]

ou

for i in range(10):
     data_batch.__getitem__(i)
Cette page vous a été utile?
0 / 5 - 0 notes