DataLoader์์ ๋จ์ผ ๋ฐฐ์น๋ฅผ ๊ฐ์ ธ์ฌ ์ ์์ต๋๊น? ํ์ฌ for ๋ฃจํ๋ฅผ ์ค์ ํ๊ณ ๋ฐฐ์น๋ฅผ ์๋์ผ๋ก ๋ฐํํฉ๋๋ค.
ํ์ฌ DataLoader๋ก ์ด ์์
์ ์ํํ ์ ์๋ ๋ฐฉ๋ฒ์ด ์๋ค๋ฉด ๊ธฐ๊บผ์ด ๊ธฐ๋ฅ์ ์ถ๊ฐํ ์ ์์ต๋๋ค.
next(iter(data_loader))
?
์ข์์, ์ ๊ฐ ์ฌ์ฉํ๋ ๊ฒ๋ณด๋ค ํจ์ฌ ๋ซ์ต๋๋ค.
๊ฐ์ฌ ํด์!
๊ทธ ๋๋ต์ RAM ๋ฉ๋ชจ๋ฆฌ์ ์ ํ ์ฆ๊ฐ๋ก ํ๋ จ์์ ๋ฉ๋ชจ๋ฆฌ ๋์๋ฅผ ์ ๋ฐํ๋ ๋ฐ๋ฉด ์ผ๋ฐ for ๋ฃจํ(๋ฐ ๋ฃจํ์์ ์ ํํ ๋์ผํ ์ฝ๋)๋ก ๊ณ์ ์ ์ ํฉ๋๋ค.
:+1: @hyperfraise๋ก. ์ด๊ฒ์ ๋ฉ๋ชจ๋ฆฌ ๋์๋ฅผ ๋ง๋ญ๋๋ค.
๋ค์(์ฝ๊ฐ ๋ค๋ฅธ) ์ฝ๋์์ ๋์ผํ ๋ฉ๋ชจ๋ฆฌ ๋์ ๋ฌธ์ :
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
X, Y = next(dataloader_iterator)
except:
dataloader_iterator = iter(train_loader)
X, Y = next(dataloader_iterator)
do_backprop(X, Y)
for ๋ฃจํ ๋์ ๋ฉ๋ชจ๋ฆฌ ์ ์ ๊ฐ ๊ณ์ ์ฆ๊ฐํฉ๋๋ค. ์ถ๊ฐ ์ ๋ณด๊ฐ ํฌํจ๋ ์ ๋ฌธ์ ๋ฅผ ์ด โโ์ ์์ต๋๋ค(์์ง ์๋ฃ๋์ง ์์ ๊ฒฝ์ฐ).
์ด๊ฒ์ ๋ฉ๋ชจ๋ฆฌ ๋์๊ฐ ์๋๋ผ ๋จ์ํ ๋ฃจํ๊ฐ ์ข ๋ฃํ ์ ์๋ ๊ฒ๋ณด๋ค ๋น ๋ฅด๊ฒ ํ๋ก์ธ์ค๋ฅผ ์์ฑํ๋ ๋ฐ ๋งค์ฐ ๋ฐ์๋ค๋ ์ฌ์ค์ ๋๋ค. DataLoader ๋ฐ๋ณต์๋ ์๋ช ์ด ๋งค์ฐ ์งง์ ๊ฐ์ฒด๋ฅผ ์๋ฏธํ์ง ์์ต๋๋ค.
๋ด ์ด์ ๋๊ธ์ด ์๋ชป๋์์ต๋๋ค. ๋๋ ๋์๊ฐ ์ฝ๋์ ๋ค๋ฅธ ๊ณณ์ ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค(๋๋ ํธ๊ธฐ์ฌ ๋ง์ ์ฌ๋๋ค์ ์ํด ๋ถ๋ฆฌํ์ง ์๊ณ vars์ ๋งค๋ฌ๋ ค ์์์ต๋๋ค).
next(iter(dataloader)) ์๋ ์ "BrokenPipeError: [Errno 32] Broken pipe"๊ฐ ๋ฐ์ํฉ๋๋ค.
์ด ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ ๋ฃจํ์์ ํ๋ จ์ ์ํ ๋ฐฐ์น๋ฅผ ๊ฒ์ํ์ต๋๋ค.
for i in range(n):
batch = next(iter(data_loader))
๋ฐ์ดํฐ ์ธํธ์ ๊ธฐ๋ณธ __getitem__
๊ฐ ๋์ผํ item
์ธ๋ฑ์ค๋ฅผ ๊ณ์ ๊ฐ์ ธ์ค๋ ๊ฒ์ฒ๋ผ ๋์ผํ ๋ฐฐ์น๋ฅผ ๊ณ์ ๊ฐ์ ธ์ค๋ ๊ฒ์ผ๋ก ๋ํ๋ฌ์ต๋๋ค.
์ด๊ฒ ์ ์์ธ๊ฐ์?
@shaibagon
์ ๋ฌธ์ํ๋์ด ์์ง๋ ์์ง๋ง iter(dataloader)
๋ฅผ ์ํํ ๋ _DataLoaderIter ํด๋์ค์ ๊ฐ์ฒด๋ฅผ ๋ง๋ค๊ณ ๋ฃจํ์์ ๋์ผํ ๊ฐ์ฒด๋ฅผ n
๋ฒ ๋ง๋ค๊ณ ์ฒซ ๋ฒ์งธ ๋ฐฐ์น๋ง ๊ฒ์ํฉ๋๋ค.
ํด๊ฒฐ ๋ฐฉ๋ฒ์ ๋ฃจํ ์ธ๋ถ์ _DataLoaderIter๋ฅผ ๋ง๋ค๊ณ ๋ฐ๋ณตํ๋ ๊ฒ์
๋๋ค. ๋ฌธ์ ๋ ๋ชจ๋ ๋ฐฐ์น๊ฐ ๊ฒ์๋๋ฉด _DataLoaderIter๊ฐ StopIteration ์ค๋ฅ ๋ฅผ ๋ฐ์์ํจ๋ค๋ ๊ฒ์
๋๋ค.
๋ฌธ์ ๋ฅผ ํผํ๊ธฐ ์ํด ํ์ฌ ํ๊ณ ์๋ ์์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
data, target = next(dataloader_iterator)
except StopIteration:
dataloader_iterator = iter(dataloader)
data, target = next(dataloader_iterator)
do_something()
๊ทธ๊ฒ์ ๋งค์ฐ ์ถ์ ํ์ง๋ง ์ ์๋ํฉ๋๋ค.
์ฌ๊ธฐ ์์ ์ ์ฌํ ๋ฌธ์ ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค. ์ด ์ค๋ ๋์์ ์ ์๋ ์๋ฃจ์ ์ด ์ฌ๋๋ค์๊ฒ๋ ๋์์ด ๋์์ผ๋ฉด ํฉ๋๋ค.
@ srossi93 ์ข์ ์๋ฃจ์
์
๋๋ค. ๋๋๋ก ๋ฐ๋ณต ์ฃผ๊ธฐ๊ฐ ๋๋ ๋ ๋ฌด์๋ ์์ธ๊ฐ ๋ฐ์ํฉ๋๋ค: ConnectionResetError: [Errno 104] Connection reset by peer
.
๋ค์ค ์ฒ๋ฆฌ๋ก ์ธํด ๋ฐ์ํ๋ ๊ฒ์ผ๋ก ๋ณด์ ๋๋ค. ๋ฐ์ดํฐ ๋ก๋์ num_workers๋ฅผ 0์ผ๋ก ์ค์ ํ๋ฉด ์ค๋ฅ๊ฐ ์ฌ๋ผ์ง๋๋ค. ๋ค๋ฅธ ์๋ฃจ์ ์ด ์์ต๋๊น?
thx @srossi93
์ด์ฉ๋ฉด ์ด ์ฝ๋๊ฐ ์กฐ๊ธ ๋ ๋์๊น์?
def inf_train_gen():
while True:
for images, targets in enumerate(dataloader):
yield images, targets
gen = inf_train_gen
for it in range(num_iters):
images, targets = gen.next()
์์ ์ ๊ณต๋ ์ฝ๋๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฐ์ดํฐ ์ธํธ๊ฐ ์์ด๋์?
dataloader_iterator = iter(dataloader)
for i in range(iterations):
try:
X, Y = next(dataloader_iterator)
except:
dataloader_iterator = iter(train_loader)
X, Y = next(dataloader_iterator)
do_backprop(X, Y)
์ด๊ฒ์ด ์ด๋ฏธ ์์ฑ๊ธฐ/๋ฐ๋ณต์๊ฐ ์๋ ์ด์ ๋ ๋ฌด์์ ๋๊น?
@Yamin05114 iter(dataloader)์ ๊ฒฐ๊ณผ๊ฐ ์ฌ์ค์ ๋ ๋๋ง๋ค ์ ํ๋๋์ง ํ์ธํ๊ธฐ ์ํด ์์ ์์ ๋ฅผ ์คํํ์ต๋๋ค. ์๋์์ ์ด ์์ ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๋ฉด ์์ ์ธ์ ๋ฌธ ์ธํธ๋ฅผ ๋ณด๊ณ ์ฃผ๋ฌธ์ด ์ค์ ๋ก ์์ธ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. ์ด๊ฒ์ ์ผ๋ฐ์ ์ธ ์ฆ๊ฑฐ๋ ์๋์ง๋ง iter(train_loader)๋ฅผ ํธ์ถํ ๋๋ง๋ค ๋ฐ์ดํฐ๊ฐ ๋ค์์ธ๋ค๋ ์ค๋๋ ฅ ์๋ ์ฆ๊ฑฐ์ ๋๋ค.
import torch
from torch.utils.data import Dataset, DataLoader
dataset = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=3)
iterloader = iter(dataloader)
for i in range(0, 12):
try:
batch = next(iterloader)
except StopIteration:
iterloader = iter(dataloader)
batch = next(iterloader)
print("iteration" + str(i))
print(batch)
๋ํ @shaibagon ์ ์ค๋ฅ๋ฅผ ์ฌํํ ์ ์์์ต๋๋ค...์๋ ์ฝ๋๋ ๋ณ๊ฐ์ ๋ฐฐ์น(์์์ ์ ์ํ ๊ฒ๊ณผ ๋์ผํ ๋ณ์ ์ฌ์ฉ)๋ฅผ ์์ฑํ๋ ๊ฒ ๊ฐ์์ ๊ฑฐ๊ธฐ์ ๋ฌด์จ ์ผ์ด ์ผ์ด๋ฌ๋์ง ํ์คํ์ง ์์ต๋๋ค.
for i in range(0, 12):
batch = next(iter(dataloader))
print("iteration: " + str(i))
print(batch)
pytorch์์ data.Dataset
๋ฅผ ์์ํ๋ dataset
๊ฐ์ฒด๊ฐ ์๋ ๊ฒฝ์ฐ idx๋ฅผ ์ธ์๋ก ์ฌ์ฉํ๋ __getitem__
๋ฉ์๋๋ฅผ ์ฌ์ ์ํด์ผ ํฉ๋๋ค. ๋ฐ๋ผ์ ์ง์ ์ก์ธ์คํ ์ ์์ต๋๋ค.
**some dataset instance called _data_
data=Dataset(**kwargs)
for i in range(10):
data[i]
๋๋
for i in range(10):
data_batch.__getitem__(i)
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
next(iter(data_loader))
?