PyTorch: Get a single batch from DataLoader without iterating

Created on June 26, 2017 · 18 comments · Source: pytorch/pytorch

Is it possible to get a single batch from a DataLoader? Currently, I set up a for loop and return a batch manually.
If there is no way to do this with the DataLoader at the moment, I would be happy to work on adding the functionality.

All 18 comments

next(iter(data_loader)) ?

Sweet, that's much better than what I was using.
Thanks!
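
A minimal sketch of the `next(iter(data_loader))` suggestion, assuming a toy `TensorDataset`. The tensors, sizes, and variable names are illustrative and do not come from the thread:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (an assumption for illustration): 10 samples with 3 features each.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
data_loader = DataLoader(TensorDataset(features, labels), batch_size=4)

# iter() builds a fresh iterator; next() pulls exactly one batch from it.
batch_features, batch_labels = next(iter(data_loader))
print(batch_features.shape, batch_labels.shape)
```

Because `shuffle` defaults to `False`, this always returns the first four samples. It is convenient for a quick shape check, less so for fetching training batches in a loop.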

That answer causes a memory leak during training, with RAM usage growing linearly, whereas a regular for loop (with exactly the same code inside the loop) keeps memory usage constant.

:+1: to @hyperfraise. This creates a memory leak.

Same memory leak issue with the following (slightly different) code:

dataloader_iterator = iter(dataloader)
for i in range(iterations):
    try:
        X, Y = next(dataloader_iterator)
    except StopIteration:
        # The iterator is exhausted: start a new pass over the data.
        dataloader_iterator = iter(dataloader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

Memory usage keeps increasing during the for loop. I can open a new issue with additional information (if that has not been done yet).

This is not a memory leak. You are simply creating processes faster than they can exit. DataLoader iterators are not meant to be very short-lived objects.

My previous comment was wrong. I found that the leak was elsewhere in my code (for the curious: I was hanging on to variables without detaching them).
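
The thread does not show the code behind that leak, so the following is only a guess at the common pattern it describes: accumulating a loss tensor keeps every step's autograd graph alive, while accumulating `loss.item()` stores just the number. The model, data, and variable names are assumptions made for illustration:

```python
import torch

# Hypothetical model and data, used only to produce a loss with autograd history.
model = torch.nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

running_attached = 0.0  # becomes a tensor that references every step's graph
running_detached = 0.0  # stays a plain Python float

for _ in range(3):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    running_attached = running_attached + loss         # keeps autograd history alive
    running_detached = running_detached + loss.item()  # copies the value only

print(running_attached.grad_fn is not None)
print(type(running_detached))
```

Over many iterations the first accumulator makes memory grow steadily, which can look like a DataLoader leak even though the loader is not involved.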

I get "BrokenPipeError: [Errno 32] Broken pipe" when trying next(iter(dataloader)).

I used this method to retrieve batches for training in a loop:

    for i in range(n):
        batch = next(iter(data_loader))

It turned out that I kept getting the same batch over and over, as if the dataset's underlying __getitem__ kept receiving the same item index.
Is this normal?

@shaibagon
It is not well documented, but calling iter(dataloader) creates an object of class _DataLoaderIter. In your loop you create that object n times and retrieve only the first batch from each one.
A workaround is to create the _DataLoaderIter outside the loop and iterate over it. The catch is that once all batches have been retrieved, the _DataLoaderIter raises a StopIteration error.

To avoid problems, this is what I am currently doing:

    dataloader_iterator = iter(dataloader)
    for i in range(iterations):
        try:
            data, target = next(dataloader_iterator)
        except StopIteration:
            dataloader_iterator = iter(dataloader)
            data, target = next(dataloader_iterator)
        do_something()

It's very ugly, but it works just fine.

A similar issue can be found here. I hope the solution proposed in this thread helps people there as well.

@srossi93 Nice solution. Sometimes an ignored exception is raised at the end of an iteration cycle: ConnectionResetError: [Errno 104] Connection reset by peer.

It seems to be caused by multiprocessing. If I set the data loader's num_workers to 0, the error goes away. Is there any other solution?

thx @srossi93

Maybe this code is a little better?

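def inf_train_gen():
    # Loop over the dataloader forever, starting a new pass when one ends.
    while True:
        for images, targets in dataloader:
            yield images, targets

gen = inf_train_gen()  # call the function to create the generator
for it in range(num_iters):
    images, targets = next(gen)  # Python 3 replacement for gen.next()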

Will the dataset be shuffled if I use the code provided above?

dataloader_iterator = iter(dataloader)
for i in range(iterations):
    try:
        X, Y = next(dataloader_iterator)
    except StopIteration:
        dataloader_iterator = iter(dataloader)
        X, Y = next(dataloader_iterator)
    do_backprop(X, Y)

Why is this not already a generator/iterator?

@Yamin05114 I ran a small example to check whether the result of iter(dataloader) is reshuffled every time it is reset. If you run the small script below, you can look at the printed output and see that the order is indeed shuffled. This is not a general proof, but it is convincing evidence that the data is reshuffled every time iter(dataloader) is called.

import torch
from torch.utils.data import Dataset, DataLoader

dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)

for i in range(0, 12):

    try:
        batch = next(iterloader)
    except StopIteration:
        iterloader = iter(dataloader)
        batch = next(iterloader)

    print("iteration: " + str(i))
    print(batch)
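
A variant of the same check that groups the output by pass, which can make the reshuffling easier to see. This is only a sketch: `num_workers=0` and the seed are choices made here so the example runs in a single process and is repeatable, and neither comes from the script above.

```python
import torch
from torch.utils.data import DataLoader

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable
dataset = torch.arange(8)
# num_workers=0 keeps loading in the main process.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

epochs = []
for _ in range(3):
    # Each `for` loop calls iter(dataloader), which draws a new permutation.
    order = [int(x) for batch in dataloader for x in batch]
    epochs.append(order)
    print(order)
```

Each printed list contains every sample exactly once, and the order is drawn again for each pass.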

I also could not reproduce @shaibagon's error. The code below seems to produce distinct batches (using the same variables as defined above), so I am not sure what happened there.

for i in range(0, 12):
    batch = next(iter(dataloader))
    print("iteration: " + str(i))
    print(batch)
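
One possible explanation for the discrepancy, offered as an assumption because the thread does not show @shaibagon's loader configuration: with `shuffle=False`, every fresh iterator starts at index 0, so `next(iter(dataloader))` returns the first batch each time. With `shuffle=True`, as in the script above, each fresh iterator draws a new permutation, so the batches look distinct. A small sketch of the unshuffled case:

```python
import torch
from torch.utils.data import DataLoader

dataset = torch.arange(8)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

# A new iterator per call: always the first batch.
repeated = [next(iter(loader)).tolist() for _ in range(3)]
print(repeated)

# One iterator kept across calls: the batches advance.
it = iter(loader)
advancing = [next(it).tolist() for _ in range(3)]
print(advancing)
```

Here `repeated` holds `[0, 1]` three times, while `advancing` steps through `[0, 1]`, `[2, 3]`, `[4, 5]`.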

In PyTorch, a dataset object that inherits from data.Dataset must override the __getitem__ method, which takes idx as its argument. You can therefore access items directly.

# some dataset instance called data (Dataset stands for your own subclass)
data = Dataset(**kwargs)
for i in range(10):
    data[i]

or

for i in range(10):
    data.__getitem__(i)
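
To make the idea concrete, here is a small sketch with a hypothetical `SquaresDataset` subclass. The class and its contents are invented for illustration and are not part of PyTorch or of this thread:

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):  # hypothetical example class
    """Holds the squares 0, 1, 4, ... of the first n integers."""

    def __init__(self, n):
        self.values = torch.arange(n) ** 2

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        return self.values[idx]


data = SquaresDataset(5)
first_three = [int(data[i]) for i in range(3)]  # data[i] calls __getitem__(i)
print(first_three)
```

Indexing the dataset directly returns single samples and bypasses the DataLoader's batching, shuffling, and worker processes, so it does not replace fetching a batch.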