Pytorch: احصل على دفعة واحدة من DataLoader بدون تكرار

تم إنشاؤها على ٢٦ يونيو ٢٠١٧  ·  18تعليقات  ·  مصدر: pytorch/pytorch

هل من الممكن الحصول على دفعة واحدة من DataLoader؟ حاليًا ، أقوم بإعداد حلقة for وأعيد دفعة يدويًا.
إذا لم تكن هناك طريقة للقيام بذلك باستخدام DataLoader حاليًا ، فسيسعدني العمل على إضافة الوظيفة.

التعليق الأكثر فائدة

next(iter(data_loader)) ؟

ال 18 كومينتر

next(iter(data_loader)) ؟

رائع ، هذا أفضل بكثير مما كنت أستخدمه.
شكرا!

تثير هذه الإجابة تسريبًا للذاكرة في تدريبي من خلال الزيادة الخطية لذاكرة RAM ، أثناء العمل المستمر مع حلقة for عادية (ونفس الكود في الحلقة تمامًا): /

: +1: tohyperfraise. يؤدي هذا إلى حدوث تسرب للذاكرة.

نفس مشكلة تسرب الذاكرة مع الكود التالي (مختلف قليلاً):

dataloader_iterator = iter(dataloader)
for i in range(iterations):     
    try:
        X, Y = next(dataloader_iterator)
    except:
        dataloader_iterator = iter(train_loader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

يزداد احتلال الذاكرة باستمرار أثناء الحلقة. قد أفتح مشكلة جديدة بمزيد من المعلومات (إذا لم يتم ذلك بعد)

قد لا يكون هذا تسربًا للذاكرة ولكن ببساطة حقيقة أن الحلقة الخاصة بك مشغولة للغاية بعمليات تفريخ أسرع مما يمكننا إنهاءها. لا يُقصد من مكررات DataLoader أن تكون كائنات قصيرة العمر

تعليقي السابق كان غير صحيح. اكتشفت أن التسريب كان في مكان آخر في الكود (كنت أتشبث به دون أن أفصل عن أولئك الذين لديهم فضول).

الحصول على "BrokenPipeError: [Errno 32] Broken pipe" عند المحاولة التالية (iter (أداة تحميل البيانات))

لقد استخدمت هذه الطريقة لاسترداد دفعات للتدريب في حلقة:

    for i in range(n):
       batch = next(iter(data_loader))

لقد لاحظت أنني أستمر في الحصول على نفس الدفعة ، مثل __getitem__ الأساسي لمجموعة البيانات يستمر في الحصول على نفس الفهرس item .
هل هذا طبيعي؟

تضمين التغريدة
لم يتم توثيقه جيدًا ولكن عندما تقوم بعمل iter(dataloader) تقوم بإنشاء كائن من class _DataLoaderIter ، وفي الحلقة ، ستقوم بإنشاء نفس الكائن n مرات واسترداد الدفعة الأولى فقط.
الحل هو إنشاء _DataLoaderIter خارج الحلقة وتكرارها. تكمن المشكلة في أنه بمجرد استرداد جميع الدُفعات ، فإن _DataLoaderIter سيرفع خطأ StopIteration .

لتجنب المشاكل ، ما أفعله حاليًا هو ما يلي:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

إنه قبيح للغاية ولكنه يعمل بشكل جيد.

يمكن العثور على مشكلة مماثلة هنا وآمل أن تساعد الحلول المقترحة في هذا الموضوع الأشخاص هناك أيضًا.

@ srossi93 حل جميل. أحيانًا أحصل على استثناء تم تجاهله عند انتهاء دورة التكرارات: ConnectionResetError: [Errno 104] Connection reset by peer .

يبدو أنه ناتج عن المعالجة المتعددة. يؤدي تعيين عدد العاملين على أداة تحميل البيانات إلى 0 إلى اختفاء الخطأ. أي حلول أخرى؟

هههههههههههههههههههههههههههه

ربما هذا الرمز أفضل قليلاً؟

def inf_train_gen():
    while True:
        for images, targets in enumerate(dataloader):
            yield images, targets
gen = inf_train_gen
for it in range(num_iters):
    images, targets = gen.next()    

هل سيتم تبديل مجموعة البيانات عشوائيًا إذا استخدمت الكود المقدم أعلاه؟
dataloader_iterator = iter(dataloader) for i in range(iterations): try: X, Y = next(dataloader_iterator) except: dataloader_iterator = iter(train_loader) X, Y = next(dataloader_iterator) do_backprop(X, Y)

لماذا هذا ليس مولد / مكرر بالفعل؟

@ Yamin05114 قمت بتشغيل مثال صغير لمعرفة ما إذا كانت نتيجة iter (أداة تحميل البيانات) يتم خلطها عشوائيًا في كل مرة تتم إعادة تعيينها. إذا قمت بتشغيل هذا البرنامج النصي الصغير أدناه ، فيمكنك إلقاء نظرة على مجموعة صغيرة من عبارات الطباعة لتأكيد بنفسك أن الأمر قد تم خلطه بشكل عشوائي. هذا ليس دليلاً بشكل عام ، لكنه دليل مقنع على أن البيانات يتم خلطها في كل مرة نسميها iter (train_loader).

import torch
from torch.utils.data import Dataset, DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration" + str(i))
    print(batch)

بالإضافة إلى ذلك ، لم أتمكن من إعادة إنتاج خطأ shaibagon ... يبدو أن الكود أدناه ينتج دفعات مميزة (باستخدام نفس المتغيرات كما هو محدد أعلاه) ، لذلك لست متأكدًا مما حدث هناك.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)

إذا كان لديك كائن dataset يرث data.Dataset من pytorch ، فيجب أن يتجاوز طريقة __getitem__ ، التي تستخدم idx كوسيطة. لذلك يمكنك الوصول إليه مباشرة:

**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
     data[i]

أو

for i in range(10):
     data_batch.__getitem__(i)
هل كانت هذه الصفحة مفيدة؟
0 / 5 - 0 التقييمات