Pytorch: Dapatkan satu batch dari DataLoader tanpa iterasi

Dibuat pada 26 Jun 2017  ·  18Komentar  ·  Sumber: pytorch/pytorch

Apakah mungkin untuk mendapatkan satu batch dari DataLoader? Saat ini, saya menyiapkan for loop dan mengembalikan batch secara manual.
Jika tidak ada cara untuk melakukan ini dengan DataLoader saat ini, saya akan dengan senang hati menambahkan fungsionalitas.

Komentar yang paling membantu

next(iter(data_loader)) ?

Semua 18 komentar

next(iter(data_loader)) ?

Keren, itu jauh lebih baik daripada yang saya gunakan.
Terima kasih!

Jawaban itu memicu kebocoran memori dalam pelatihan saya dengan augmentasi linier memori RAM, sementara pendudukan konstan dengan loop for reguler (dan kode yang persis sama dalam loop):/

:+1: ke @hyperfraise. ini menciptakan kebocoran memori.

Masalah kebocoran memori yang sama dengan kode berikut (sedikit berbeda):

dataloader_iterator = iter(dataloader)
for i in range(iterations):     
    try:
        X, Y = next(dataloader_iterator)
    except:
        dataloader_iterator = iter(train_loader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

Pendudukan memori terus meningkat selama for-loop. Saya mungkin membuka masalah baru dengan informasi lebih lanjut (jika belum selesai)

Ini mungkin bukan kebocoran memori tetapi hanya fakta bahwa loop Anda sangat sibuk dengan proses pemijahan lebih cepat daripada yang dapat kami hentikan. Iterator DataLoader tidak dimaksudkan untuk menjadi objek yang berumur pendek

Komentar saya sebelumnya salah. Saya menemukan bahwa kebocoran ada di tempat lain dalam kode (saya bergantung pada vars tanpa melepaskan untuk mereka yang penasaran).

dapatkan "BrokenPipeError: [Errno 32] Broken pipe" saat mencoba berikutnya(iter(dataloader))

Saya menggunakan metode ini untuk mengambil batch untuk pelatihan dalam satu lingkaran:

    for i in range(n):
       batch = next(iter(data_loader))

Saya perhatikan saya terus mendapatkan kumpulan yang sama, seperti __getitem__ yang mendasari dataset terus mendapatkan indeks item yang sama.
Apakah ini normal?

@shaibagon
Itu tidak didokumentasikan dengan baik tetapi ketika Anda melakukan iter(dataloader) Anda membuat objek kelas _DataLoaderIter dan, dalam loop, Anda akan membuat objek yang sama n kali dan mengambil batch pertama saja.
Solusinya adalah membuat _DataLoaderIter di luar loop dan mengulanginya. Masalahnya adalah setelah semua kumpulan diambil, _DataLoaderIter akan memunculkan kesalahan StopIteration .

Untuk menghindari masalah, apa yang saya lakukan saat ini adalah sebagai berikut:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

Ini sangat jelek tetapi berfungsi dengan baik.

masalah serupa dapat ditemukan di sini Saya harap solusi yang diusulkan di utas ini dapat membantu orang-orang di sana juga.

@srossi93 solusi yang bagus. Terkadang saya mendapatkan pengecualian yang diabaikan ketika siklus iterasi berakhir: ConnectionResetError: [Errno 104] Connection reset by peer .

Tampaknya disebabkan oleh multiprosesor. Menyetel num_workers pada dataloader ke 0 membuat kesalahan hilang. Ada solusi lain?

thx @srossi93

Mungkin kode ini sedikit lebih baik?

def inf_train_gen():
    while True:
        for images, targets in enumerate(dataloader):
            yield images, targets
gen = inf_train_gen
for it in range(num_iters):
    images, targets = gen.next()    

Apakah dataset akan diacak jika saya menggunakan kode yang diberikan di atas?
dataloader_iterator = iter(dataloader) for i in range(iterations): try: X, Y = next(dataloader_iterator) except: dataloader_iterator = iter(train_loader) X, Y = next(dataloader_iterator) do_backprop(X, Y)

mengapa ini bukan generator/iterator?

@Yamin05114 Saya menjalankan contoh kecil untuk melihat apakah hasil iter(dataloader) dikocok setiap kali disetel ulang. Jika Anda menjalankan skrip kecil di bawah ini, Anda dapat melihat kumpulan kecil pernyataan cetak untuk mengonfirmasi sendiri bahwa pesanan memang diacak. Ini bukan bukti secara umum, tetapi ini adalah bukti yang meyakinkan bahwa data dikocok setiap kali kita memanggil iter(train_loader).

import torch
from torch.utils.data import Dataset, DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration" + str(i))
    print(batch)

Selain itu, saya tidak dapat mereproduksi kesalahan @shaibagon ... kode di bawah ini tampaknya menghasilkan kumpulan yang berbeda (menggunakan variabel yang sama seperti yang didefinisikan di atas), jadi tidak yakin apa yang terjadi di sana.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)

Jika Anda memiliki objek dataset yang mewarisi data.Dataset dari pytorch, itu harus menimpa metode __getitem__ , yang menggunakan idx sebagai argumen. Oleh karena itu Anda dapat mengaksesnya secara langsung:

**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
     data[i]

atau

for i in range(10):
     data_batch.__getitem__(i)
Apakah halaman ini membantu?
0 / 5 - 0 peringkat