ํธ์ง์ ์ฐธ๊ณ ์ฌํญ: ์ด ๋ฌธ์ ์ ๋ํด ์๋ ค์ง ํด๊ฒฐ ๋ฐฉ๋ฒ์ Python ๋ชฉ๋ก์ ์ฌ์ฉํ์ง ์๊ณ ๋์ numpy ๋ฐฐ์ด ๋๋ ํ ์๋ฅผ ์ง์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค.
DataLoader num_workers > 0
์ด๋ฉด CPU ๋ฉ๋ชจ๋ฆฌ ๋์๊ฐ ๋ฐ์ํฉ๋๋ค.
๋ค์ ์ค๋ํซ์ ์คํํฉ๋๋ค.
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import os
class DataIter(Dataset):
def __init__(self):
path = "path/to/data"
self.data = []
for cls in os.listdir(path):
for img in os.listdir(os.path.join(path, cls)):
self.data.append(os.path.join(path, cls, img))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
with Image.open(self.data[idx]) as img:
img = img.convert('RGB')
return transforms.functional.to_tensor(img)
train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
shuffle=True,
drop_last=True,
pin_memory=False,
num_workers=18)
for i, item in enumerate(train_loader):
if i % 200 == 0:
print(i)
CPU ๋ฉ๋ชจ๋ฆฌ๋ ์ ์ฐจ ์ฆ๊ฐํ๊ธฐ ์์ํ์ฌ ๊ฒฐ๊ตญ ์ ์ฒด RAM์ ์ฑ์๋๋ค. ์๋ฅผ ๋ค์ด, ํ๋ก์ธ์ค๋ ์ฝ 15GB์์ ์์ํ์ฌ ์์คํ
์์ ์ฌ์ฉ ๊ฐ๋ฅํ ์ ์ฒด 128GB๋ฅผ ์ฑ์๋๋ค.
num_workers=0
์ผ ๋ RAM ์ฌ์ฉ๋์ ์ผ์ ํฉ๋๋ค.
PyTorch version: 1.0.0.dev20181028
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
Nvidia driver version: 390.67
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4
Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect
PIL.__version__
'5.3.0'
๋ฐ์ดํฐ ์ธํธ์๋ ์ฝ 2,400๋ง ๊ฐ์ ์ด๋ฏธ์ง๊ฐ ์์ผ๋ฉฐ ๋ชจ๋ ์ด๋ฏธ์ง ๊ฒฝ๋ก๋ ์์ ์ฝ๋ ์ค๋ํซ์ ํ์๋ ๋๋ก ๋จ์ผ ๋ชฉ๋ก์ ๋ก๋๋ฉ๋๋ค.
๋ํ ์ฌ๋ฌ Pytorch(0.4.0 ๋ฐ 0.4.1) ๋ฒ์ ์ ์๋ํ์ง๋ง ํจ๊ณผ๋ ๋์ผํฉ๋๋ค.
cc @ezyang @gchanan @zou3519 @SsnL
๋ฐ๋ณตํ ๋ ๋๋ ๋ฐ๋ณต์ ์์ํ๊ธฐ ์ ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ์ฆ๊ฐํ๋ ๊ฒ์ด ๋ณด์ ๋๊น?
@SsnL ๋ฐ๋ณต ์ค์๋ง.
#13243์ ์์ ํ ๋ ์ด๊ฒ๋ ์์ ๋์๋์ง ํ์ธํด์ผ ํฉ๋๋ค.
num_workers>0
batch_sampler
๋ฅผ ์ฌ์ฉํ ๋ OOM์ด ํธ๋ฆฌ๊ฑฐ๋ ๋๊น์ง ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ์ง์์ ์ผ๋ก ์ฆ๊ฐํ๋ ๋น์ทํ ์ํฉ์ ๊ฒฝํํ์ต๋๋ค.
import math
from torch.utils.data import DataLoader
class Sampler:
def __init__(self, n=100000, batch_size=32):
self.n = n
self.batch_size = batch_size
def __len__(self):
return math.ceil(float(self.n)/self.batch_size)
def __iter__(self):
batch = []
for i in range(self.n):
batch.append(i)
if len(batch) == self.batch_size:
yield batch
batch = []
if batch:
yield batch
N = 100000000
train_data = list(range(N))
def ok():
train_sampler = Sampler(len(train_data))
train_loader = DataLoader(train_data,
num_workers=0,
batch_sampler=train_sampler)
for i, item in enumerate(train_loader):
if i % 10000 == 0:
print(i)
def leaky():
train_sampler = Sampler(len(train_data))
train_loader = DataLoader(train_data,
num_workers=8,
batch_sampler=train_sampler)
for i, item in enumerate(train_loader):
if i % 10000 == 0:
print(i)
print('Starting ok')
ok()
print('ok done, starting leaky()')
leaky()
print('leaky done')
$ python3 collect_env.py
Collecting environment information...
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 9.1.85
OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: GeForce GTX 1050 Ti with Max-Q Design
Nvidia driver version: 390.77
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.2
/usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect
@ezyang
#13243์ ์์ ํ ๋ ์ด๊ฒ๋ ์์ ๋์๋์ง ํ์ธํด์ผ ํฉ๋๋ค.
์ด ๋ฌธ์ ๋ 1.0.0.dev20181105
์ ์ฌ์ ํ ์กด์ฌํ๋ฉฐ, ์ฌ๊ธฐ์ #13243์ด ์์ ๋์์ต๋๋ค.
์ข ๋ ์กฐ์ฌํ ํ์ ๋์ถ์ด ๋ฐ์ํ๋ ์ ํํ ์๋๋ฆฌ์ค๋ฅผ ์ฐพ์์ต๋๋ค. ์๋ ์ฝ๋ ์์ ๋ฅผ ๊ณ ๋ คํ์ญ์์ค.
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
class DataIter(Dataset):
def __init__(self):
self.data_np = np.array([x for x in range(24000000)])
self.data = [x for x in range(24000000)]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
data = np.array([data], dtype=np.int64)
return torch.tensor(data)
train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
shuffle=True,
drop_last=True,
pin_memory=False,
num_workers=18)
for i, item in enumerate(train_loader):
if i % 1000 == 0:
print(i)
Python์ ํ์ค ์ ์ ๋ชฉ๋ก์ธ self.data
๋ณ์๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฐ์ดํฐ ๋์๊ฐ ๋ฐ์ํฉ๋๋ค. ๊ทธ๋ฌ๋ self.data_np
๋ณ์๋ฅผ ์ฌ์ฉํ๋ฉด ๋์ผํ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ง๊ณ ์์ง๋ง Numpy ๋ฐฐ์ด ํํ๋ก ๋์ด ์์ผ๋ฉด ๋์๊ฐ ๋ฐ์ํ์ง ์์ต๋๋ค.
๋ ๋ค๋ฅธ ๊ด์ฐฐ์ DataLoader
์ shuffle=False
์ธ ๊ฒฝ์ฐ ๋์ถ์ด ํจ์ฌ ๋ ์ฌ๊ฐํ๋ค๋ ๊ฒ์
๋๋ค.
๋น์ทํ ๋ฌธ์ ์ ์ง๋ฉดํ์ง๋ง ์ ๊ฒฝ์ฐ์๋ numpy ๋ฐฐ์ด์์๋ ๋ฐ์ํฉ๋๋ค. ์ ๋ Python 3.7 ๋ฐ PyTorch ์ผ๊ฐ ๋ฆด๋ฆฌ์ค๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๋ค.
๋ฉํฐํ๋ก์ธ์ฑ์ด pytorch์ ํ๋ ์๋์์ ์ค์ ๋ก ์ด๋ป๊ฒ ์๋ํ๋์ง ๋ชจ๋ฅด์ง๋ง ์ฐ๋ฆฌ๋ fast.ai ํฌ๋ผ(https://forums.txt)์์ ์ด "๋ฉ๋ชจ๋ฆฌ ๋์" ๋ฌธ์ (์๋ง๋ ๋ฉ๋ชจ๋ฆฌ ๋์๊ฐ ์๋ ์๋ ์์ต๋๋ค!)์ ๋ํด ๊ด๋ฒ์ํ๊ฒ ๋ ผ์ํ์ต๋๋ค. fast.ai/t/runtimeerror-dataloader-worker-is-killed-by-signal/31277/55?u=marcmuc). ์ฌ๊ธฐ์ ์ฝ๊ฐ์ ํต์ฐฐ๋ ฅ์ ์ถ๊ฐํ ์ ์๋ ์๋น ๊ฒฐ๊ณผ(์ด๊ฒ์ด ์ ์ฉ๋์ง ์๋ ๊ฒฝ์ฐ ์๊ฒฌ์ ๋งํ์ญ์์ค!):
Python Multiprocessing: Python ์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ์ ์์์ Python ๊ฐ์ฒด(๋จ์ํ ๋ชฉ๋ก ํฌํจ)๋ฅผ ์ ์ฅํ ์ ์๋ ๋ฐฉ๋ฒ์ ์์ต๋๋ค. refcounts๋ ๋ฉ๋ชจ๋ฆฌ ํ์ด์ง ๋จ์๋ก ์ถ๊ฐ๋๋ฏ๋ก ์๋น๊ฐ ์ฒ์ฒํ ์ฆ๊ฐํฉ๋๋ค. ํ๋ก์ธ์ค(์์ ์)๋ ๊ฒฐ๊ตญ ๋ชจ๋ /๋๋ถ๋ถ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋นํธ ๋จ์๋ก ๋ณต์ฌํ๊ฒ ๋๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํ๋ก ๋ฌธ์ ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด ๋์์ ๋ํ ๊ฐ์ฅ ์ข์ ์ค๋ช ์ ์ฌ๊ธฐ (SO)์ ๋๋ค.
๊ฐ๋ฅํ ํด๊ฒฐ์ฑ
:
์ง๊ธ์ฒ๋ผ ๋ฉํฐํ๋ก์ธ์ฑ ์ฌ์ฉํ๊ธฐ: ํ์ด์ฌ ๋ฉํฐํ๋ก์ธ์ฑ์ด ์ด๋ฌํ refcount ํจ๊ณผ ์์ด ์๋ํ๋ ค๋ฉด ํ๋ก์ธ์ค ํ์ด ์์ฑ๋๊ณ ์์
์๊ฐ ๋ถ๊ธฐ๋๊ธฐ ์ ์ ๊ฐ์ฒด๊ฐ "ํธํ"๋๊ณ multiprocessing.Array
๋ก ๋ํ๋์ด์ผ ํฉ๋๋ค. ์ด๊ฒ์ ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ค์ ๋ก ๊ณต์ ๋๊ณ ๊ธฐ๋ก ์ค ๋ณต์ฌ๊ฐ ๋ฐ์ํ์ง ์์์ ๋ณด์ฅํฉ๋๋ค. ์ด๊ฒ์ numpy ๋ฐฐ์ด์ ๋ํด ์ํํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช
ํ๊ณ ๊ทธ ์ด์ ๋ฅผ ๋ค์ ์ค๋ช
ํฉ๋๋ค. copy-on-write๊ฐ ์ด ๋ชจ๋ ๊ฒ์ ๋ถํ์ํ๊ฒ ๋ง๋ ๋ค๋ ์ด ์ข์ ๋ต๋ณ์ ์์ฑ์๋ผ๋ ์ผ๋ถ ์๋ชป๋ ์ง์ ์ ํผ๋ํ์ง ๋ง์ญ์์ค. ์ด๋ ์ฌ์ค์ด ์๋๋๋ค. ํ ์๊ฒฌ์ ๋ํ ๋ค์๊ณผ ๊ฐ์ด ์ง์ ํฉ๋๋ค.
"์ฐธ๊ณ ๋ก Python์์ fork()๋ ์ค์ ๋ก ์ก์ธ์ค ์ ๋ณต์ฌ๋ฅผ ์๋ฏธํฉ๋๋ค(๊ฐ์ฒด์ ์ก์ธ์คํ๋ ๊ฒ๋ง์ผ๋ก๋ ์ฐธ์กฐ ํ์๊ฐ ๋ณ๊ฒฝ๋๊ธฐ ๋๋ฌธ์ ๋๋ค)."
๋๋ pytorch ์ฌ์ฉ์ ์ดํดํ๋ torch.multiprocessing ๋๋กญ์ธ ๊ต์ฒด์ ์ต์ํ์ง ์์ง๋ง ํต์ฌ python refcount ๋ฌธ์ ๋ฅผ ์ ๊ฑฐํ ์ ์์ ๊ฒ์ด๋ผ๊ณ ๊ฐ์ ํฉ๋๋ค.
@mprostock torch.multiprocessing์ ์ฌ์ฉ์ ์ง์ ํผํด๋ฌ๊ฐ ์๋ ๋จ์ํ Python ๋ค์ค ์ฒ๋ฆฌ์
๋๋ค. ์ปค์คํ
ํผํด๋ฌ๋ torch.tensor
๋ฅผ ๋ง๋ ๋๋ง๋ค ์๋์ผ๋ก ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ก ์ด๋ํ๋ฏ๋ก ์ ์ด๋ torch.tensor
๊ฐ์ฒด์์๋ copy-on-write๊ฐ ๋ฐ์ํ์ง ์์ต๋๋ค.
์ค๋ช ๊ฐ์ฌํฉ๋๋ค! ๋๋ @bfreskura ์ ์ฌ์์ฐ ์์ ๋ฅผ ์คํํ๊ณ ์ด์ ๋ฌธ์ ๋ฅผ ์ ํํ ์ง์ ํ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค.
์์ bfreskura์ ์ฌ์์ฐ ์์ ๋ ์ผ๋ฐ ํ์ด์ฌ ๋ชฉ๋ก๊ณผ numpy ๋ฐฐ์ด์ ์ฐจ์ด์ ์ ๋ณด์ฌ์ฃผ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ฌธ์ ๋ ํ์ด์ฌ ๋ชฉ๋ก ์์ฒด์๋ง ์๋ ๊ฒ์ด ์๋๋ผ ๊ฐ์ฒด ์ ํ์ numpy ๋ฐฐ์ด์์๋ ๋์ผํ๊ฒ ๋ฐ์ํฉ๋๋ค. Python ๋ชฉ๋ก์ ๊ฐ์ฒด์ ๋ํ ์ฐธ์กฐ๋ง ์ ์ฅํ๊ณ ๊ฐ์ฒด๋ ๋ฉ๋ชจ๋ฆฌ์ ๋ณ๋๋ก ๋ณด๊ด๋ฉ๋๋ค. ๋ชจ๋ ๊ฐ์ฒด์๋ refcount๊ฐ ์์ผ๋ฏ๋ก ๋ชฉ๋ก์ ๋ชจ๋ ํญ๋ชฉ์๋ refcount๊ฐ ์์ต๋๋ค.
Numpy ๋ฐฐ์ด(ํ์ค np ์ ํ)์ ๋ฉ๋ชจ๋ฆฌ์ ์ฐ์ ๋ธ๋ก์ผ๋ก ์ ์ฅ๋๋ฉฐ ํ๋์ ์ฐธ์กฐ ์นด์ดํธ๊ฐ ์๋ ํ๋์ ๊ฐ์ฒด์ผ ๋ฟ์ ๋๋ค.
์ด๊ฒ์ numpy ๋ฐฐ์ด์ ๊ฐ์ฒด ์ ํ์ผ๋ก ๋ช ์์ ์ผ๋ก ๋ง๋ค๋ฉด ๋ณ๊ฒฝ๋์ด ์ผ๋ฐ ํ์ด์ฌ ๋ชฉ๋ก์ฒ๋ผ ๋์ํ๊ธฐ ์์ํฉ๋๋ค((๋ฌธ์์ด) ๊ฐ์ฒด์ ๋ํ ์ฐธ์กฐ๋ง ์ ์ฅ). ๋ฉ๋ชจ๋ฆฌ ์๋น์ ๋์ผํ "๋ฌธ์ "๊ฐ ์ด์ ๋ํ๋ฉ๋๋ค.
์ด๊ฒ์ ์ผ๋ฐ ๋ชฉ๋ก(๋๋ ๊ฐ์ฒด ์ ํ์ numpy ๋ฐฐ์ด)์์ "๋ฉ๋ชจ๋ฆฌ ๋์"๋ฅผ ๋ณด๋ ์ด์ ๋ฅผ ์ค๋ช ํฉ๋๋ค. ์ด๋ ์ค์ ๋ก ๋ฉ๋ชจ๋ฆฌ ๋์๊ฐ ์๋๋ผ ์ฐธ์กฐ ์นด์ดํธ ๋ณ๊ฒฝ์ผ๋ก ์ธํ ๋ถ๊ธฐ๋ ํ์ด์ฌ ํ๋ก์ธ์ค์ ์ก์ธ์ค ์ ๋ณต์ฌ ๋ฌธ์ ์ ๋๋ค.
๋ฐ๋ผ์ ๋ฌธ์ ๋ ์๋ง๋ (์ข ์ข ) ํ ์ ๋๋ ์ค์ ํ ์น ๊ฐ์ฒด์ ์๋ฌด ๊ด๋ จ์ด ์์ผ๋ฉฐ, ์คํ๋ ค ๋ฐ์ดํฐ ๋ก๋/๋ฐ์ดํฐ ์ธํธ ๋ด์์ ์ผ๋ฐ์ ์ผ๋ก ์ฌ์ฉ๋๋ ํ์ผ ์ด๋ฆ ๋ฐ ๋ ์ด๋ธ ์ฌ์ ๋ชฉ๋ก๊ณผ ๊ด๋ จ์ด ์์ต๋๋ค.
๋๊ตฐ๊ฐ๊ฐ ๋นจ๋ฆฌ ๊ทธ๊ฒ์ ์๋ํ๊ณ ์ถ๋ค๋ฉด ๋๋ ๋
ธํธ๋ถ ์์ง ๋ฅผ ๋ง๋ค์์ต๋๋ค.
๋ฉ๋ชจ๋ฆฌ ์๋น๋ฅผ ์ดํด๋ณด์ญ์์ค(์ ์ฒด ์์คํ
์ ๋น ๋ฅด๊ณ ๋ํฐํ ๋ฉ๋ชจ๋ฆฌ์ด๋ฏ๋ก ๋ค๋ฅธ ํ๋ก์ธ์ค์ ์์ ์ํฅ์ผ๋ก ์์คํ
์ ๊นจ๋ํ๊ฒ ์ ์งํ๋ ค๊ณ ํ์ต๋๋ค)
๊ณ ์ ๊ธธ์ด ๋ฌธ์์ด ๋ฐฐ์ด์ ๋ฉ๋ชจ๋ฆฌ ์๋น(GB):
๊ฐ์ฒด ๋ฐฐ์ด์ ๋ฉ๋ชจ๋ฆฌ ์๋น(GB)(๋ณ๊ฒฝ๋ง ๊ฐ๋ฅ)
๋๋ ๊ฐ์ ๋ฌธ์ ์ ์ง๋ฉดํ๊ณ ์์ต๋๋ค. num_workers > 0์ด๋ฉด RAM์ ๋งค์ฐ ๋น ๋ฅด๊ฒ ์ฑ์๋๋ค.
๋ด ์ฝ๋์์ ๋ ์ด์ ํ์ํ์ง ์๋ค๊ณ ์๊ฐ๋๋ ๋ณ์๋ฅผ ์ญ์ ํ๊ณ ๋ชจ๋ ๋ฐ๋ณต์์ gc.collect()๋ฅผ ํธ์ถํ์ง๋ง ์๋ฌด ๊ฒ๋ ๋์์ด ๋์ง ์์ต๋๋ค.
ํด๊ฒฐ ๋ฐฉ๋ฒ์ด ์์ต๋๊น?
dict์์ pandas๋ก, ๋ชฉ๋ก์์ numpy ๋ฐฐ์ด๋ก ์ ํํ๋ฉด ๋์์ด ๋ฉ๋๋ค.
๋๋ ๊ฐ์ ๋ฌธ์ ์ ์ง๋ฉดํ๊ณ ์์ต๋๋ค. num_workers > 0์ด๋ฉด RAM์ ๋งค์ฐ ๋น ๋ฅด๊ฒ ์ฑ์๋๋ค.
๋ด ์ฝ๋์์ ๋ ์ด์ ํ์ํ์ง ์๋ค๊ณ ์๊ฐ๋๋ ๋ณ์๋ฅผ ์ญ์ ํ๊ณ ๋ชจ๋ ๋ฐ๋ณต์์ gc.collect()๋ฅผ ํธ์ถํ์ง๋ง ์๋ฌด ๊ฒ๋ ๋์์ด ๋์ง ์์ต๋๋ค.
ํด๊ฒฐ ๋ฐฉ๋ฒ์ด ์์ต๋๊น?
๋ต์ฅ์ ๋ณด๋ด ์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค. ๋๋ ๊ทธ๊ฒ์ ์๋ํ๊ณ ํฌ๋ง์ ์ผ๋ก ์๋ํฉ๋๋ค.
์ด ๋ฌธ์ ์ ๋ํ ํด๊ฒฐ์ฑ ์ ์์ฒญํ ์ ์์ต๋๊น? ๋ง์ง๋ง์ผ๋ก ๋งค์ผ ๋น๋๋ pytorch์์ @samgd ์ฝ๋๋ฅผ ์๋ํ์ง๋ง ์ฌ์ ํ ๋์ถ๋๊ณ ์์์ต๋๋ค.
@Godricly ์์ @mprostock ๋ฐ @soumith ์ ์๊ฒฌ์ ์ฐธ์กฐํ์ญ์์ค. ์ด๊ฒ์ ์ค์ ๋ก ๋์ถ์ด ์๋์ง๋ง python ๊ธฐ๋ณธ ๋ชฉ๋ก์ ์ฌ์ฉํ๋ ๋ถํํ ๋์์ ๋๋ค. ํ ์น ํ ์ ๋๋ np ์ด๋ ์ด๋ฅผ ์ฌ์ฉํ๋ฉด ์ด ๋ฉ๋ชจ๋ฆฌ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์์ต๋๋ค.
@mprostock ๋ค๋ฅธ ๊ฒ์ด ์๋๋ผ ์ก์ธ์ค ์ ๋ณต์ฌ๋ก ์ธํด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์์ฑ๋ ๋ณต์ฌ๋ณธ์ ์๋ฏธํฉ๋๊น? ๊ทธ๋ฆฌ๊ณ ์ฌ์ฉ ํ์๋ ๋ณต์ฌ๊ฐ ํด์ ๋์ง ์์ต๋๊น?
๋๊ตฐ๊ฐ๋ ์ต์ํ ์ด๋ฏธ์ง ๋ฐ์ดํฐ ์ธํธ์ ๋ํ ์ ์ ํ ์ฆ๊ฐ ์์
์ ์์ฑํ๊ณ ๋จ๊ณ๋ฅผ ๋์ฌ์ผ ํฉ๋๋ค. ์ด๋ฌํ ๋ชจ๋ ๋ค์ค ์ฒ๋ฆฌ ์์์์ ๋ํ ๋ชจ๋ ์ด์ ๋ ๋น์ ๋ฐ์ดํฐ ์ธํธ๊ฐ ๋ค์ค ์ฝ์ด์์ ์ด๋ฏธ์ง๋ฅผ ๋์ฝ๋ฉํ๊ณ ์๋ฅด๊ธฐ _๊ฐ์ ธ ์๊ธฐ ๋๋ฌธ์
๋๋ค. ๋์ฝ๋ฉ ๋ฐ ๊ธฐํํ์ ์ด๋ฏธ์ง ๋ณํ(ํฌ๊ธฐ ์กฐ์ , ์๋ฅด๊ธฐ ๋ค์ง๊ธฐ, ์ ๋จ, ์ํ)์ ์ฒ๋ฆฌํ๊ณ ๋ฐฐ์น ํ
์๋ฅผ ์ง์ ์์ฑํ๋ ์์
์ด ์๋ค๋ฉด ๋ค์ค ์ฒ๋ฆฌ๋ฅผ ์ ํ ์ฌ์ฉํ ํ์๊ฐ ์์ผ๋ฉฐ ๋ ๋์๊ฐ ๋น๊ธฐํํ์ ์ฆ๊ฐ ๋จ๊ณ๋ฅผ ์ฌ์ฉํ ํ์๊ฐ ์์ต๋๋ค. (์์, ๋ฏธ๋ฐฑ/์ ๊ทํ, ๋
ธ์ด์ฆ)๋ ๋ด๋ถ ์ฐ์ฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์ ์ฒด ํ
์๋ฅผ ์ฐข์ ์ ์์ต๋๋ค. ์ฃผ์(๊ฒฝ๊ณ ์์, ๋ง์คํฌ, ํคํฌ์ธํธ ๋ฑ)์ ๋ณ๋ ฌ ๋ณํ์ ๊ฐ๋ฅํ๊ฒ ํ๊ธฐ ์ํด ํ
์์ ๊ฐ ์ํ์ ๋ํ ๋ณํ ๋งค๊ฐ๋ณ์๋ฅผ ์ธ๋ถ์ ๋
ธ์ถํ๋๋ก ์ด๋ฌํ ์ฐ์ฐ์ ์ค๊ณํ ๋ ์ฃผ์๋ฅผ ๊ธฐ์ธ์ฌ์ผ ํฉ๋๋ค.
๋๋ ๋ ๋์ ๋ฐฉ๋ฒ์ ์ด ์๋ฒ๋ฅผ ์ฌ๋ฌ ํ๋ก์ธ์ค(๋ฐ ๋ค๋ฅธ DL ํ๋ ์์ํฌ)์์ ์ฌ์ฉํ ์ ์๋๋ก ํ๋ ๊ฒ์
๋๋ค.
@mprostock ํ๋ฅญํ ์ค๋ช ๊ฐ์ฌํฉ๋๋ค!
๊ทธ๋ฌ๋ ์์ง ํด๊ฒฐ์ฑ ์ด ์ ์๋์ง ์์์ต๋๋ค. Dataset ๊ฐ์ฒด์ ํ์ผ ์ด๋ฆ ๋ชฉ๋ก์ ์ ์ฅํ๋ ๊ฒ์ด ๊ณต์ ํด ๋ณด์ด์ง๋ง ์ด๋ป๊ฒ ์ฌ์ฉํ ์ ์์ต๋๊น? ๋๊ตฌ๋ ์ง ๊ทธ๊ฒ์ ์์ ๋์ต๋๊น?
@1e100 @fmassa ๊ฐ torchvision
์ ๋ค์ดํฐ๋ธ ์ด๋ฏธ์ง ํ๋ ์์
์ ์ถ๊ฐํ๋ ์์
์ ํ๊ณ ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ด๋ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐ ๋์์ด ๋ ๊ฒ์
๋๋ค.
์ด ๋ฌธ์ ์ ๋ํ ์ ๋ฐ์ดํธ๊ฐ ์์ต๋๊น?
๋ง์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋์์ต๋๋ค. ๋ค์์ ๋์ปค ์ปจํ ์ด๋ ๋ด๋ถ์์ ์ฝ๋๋ฅผ ์คํํ๊ณ ์๊ณ ๊ทธ๋ ์ง ์์ผ๋ฉด ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ค์ ํ ์ ์๋ ๊ฒฝ์ฐ ์คํฌ๋ฆฝํธ ๋ด๋ถ์์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ค์ ํ๋ ํดํน์ ๋๋ค.
os.system(f"mount -o remount,size={args.shared_memory_size} /dev/shm")
๊ณต์ ๋ฉ๋ชจ๋ฆฌ ํฌ๊ธฐ๋ ์๋ฅผ ๋ค์ด ์ด RAM์ ์ ๋ฐ์ด ๋ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด biiig ๋จธ์ ์ ๊ฒฝ์ฐ '80G'์ ๋๋ค.
๋ฉ๋ชจ๋ฆฌ๊ฐ ํน์ ์ง์ ๊น์ง _์ฌ์ ํ_ ํฌ๋ฆฝ์ง๋ง ํ์ฉ๋๋ ํ์ผ ์ค๋ช
์ ์๋ฅผ ๋ณ๊ฒฝํ์ฌ ์ด ๋ฌธ์ ์ ๊ด๋ จ๋ unable to open shared memory object </torch_22291_1137042840> in read-write mode
์ค๋ฅ ๋ฐ์์ ๋ํ ํด๊ฒฐ ๋ฐฉ๋ฒ์ ์ฐพ์์ต๋๋ค.
ํ์ฉ๋ ํ์ผ ์ค๋ช
์ ์๋ฅผ ํ์ธํ๋ ค๋ฉด bash์ ulimit -a
๋ฅผ ์
๋ ฅํ๋ฉด -n
ํ๊ทธ ์๋์ ํ์๋ฉ๋๋ค. ํ์ฌ ์
ธ์ ๋ํ ์ด ์ ํ์ ๋์ด๋ ค๋ฉด(์ฆ, ์๋ฒ์ ๋ํ ๊ถํ์ด ์๋ ๊ฒฝ์ฐ) ๋ค์์ ์คํํฉ๋๋ค.
๋ฐฐ์ฌ: ulimit -n NEW_VALUE
์ ์ฒด ์์คํ ์ ๋ํด ๋ณ๊ฒฝํ๋ ค๋ฉด ์ฌ๊ธฐ๋ฅผ ์ฐธ์กฐ ํ์ญ์์ค .
๋ฐ๋ผ์ ๋ด๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ์ดํดํ๋ค๋ฉด ์์
์ ํ๋ก์ธ์ค๋ ๋ชฉ๋ก์ ์ก์ธ์คํ ๋๋ง๋ค ๊ธด ํ์ผ ๊ฒฝ๋ก ๋ชฉ๋ก์ ๋ณต์ฌ๋ณธ์ ์์ฑํฉ๋๊น? ๊ทธ๋ฌ๋ ์ด ์์ ๋ณต์ฌ๋ณธ์ ํด๋น ํ๋ก์ธ์ค์ ๋ํ __getitem__
ํจ์๊ฐ ๋ฐํ๋์๋ง์ ๋ฒ์๋ฅผ ๋ฒ์ด๋(๊ฒฐ๊ณผ์ ์ผ๋ก ํ๊ดด๋จ) ๋์ง ์์ต๋๊น? RAM ์๋น๊ฐ ์ ํ ์์ด ์ฆ๊ฐํ๋ ์ด์ ๋ ๋ฌด์์
๋๊น?
๋๊ตฐ๊ฐ๊ฐ ์ด ๋ฌธ์ ๋ฅผ ํผํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ช ๊ฐ์ง ๋ชจ๋ฒ ์ฌ๋ก๊ฐ ํฌํจ๋ ์งง์ ๊ฐ์ด๋๋ฅผ ๋ง๋ค์๋ค๋ฉด ์ข์ ๊ฒ์ ๋๋ค. ์ซ์ ๊ฐ์ ์ฌ์ฉํ๋ฉด Python ๋ชฉ๋ก์ NumPy ๋ฐฐ์ด๋ก ์ฝ๊ฒ ๊ต์ฒดํ ์ ์์ง๋ง (๊ฐ๋ณ ํฌ๊ธฐ) ๋ฌธ์์ด ๋ฌธ์ ๋ฅผ ์ํํ๋ ๋ฐฉ๋ฒ์ด ๋ช ํํ์ง ์์ต๋๋ค.
์ ๊ฒฝ์ฐ์๋ ์์ฑ์์์ ์์ฑ/์ฑ์์ง ์ฌ์ฉ์ ์ ์ ํด๋์ค ๊ฐ์ฒด ๋ชฉ๋ก์ด ์์ต๋๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก ํ์ผ ๊ฒฝ๋ก ์ธํธ๋ง ํฌํจํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ __getitem__
๋ด๋ถ์์ ํด๋น ์ด๋ฏธ์ง๋ฅผ ๋ก๋ํ๊ณ , ์ผ๋ถ ์ ์ฒ๋ฆฌ๋ฅผ ์ํํ๊ณ , ํ ์น ํ
์๋ก ๋ณํํ ๋ค์ ๋ฐํํ๊ธฐ ์ ์ ๋ก๋๋ ์ด๋ฏธ์ง์์ ๋ช
์์ ์ผ๋ก del
๋ฅผ ๊ณ์ฐํฉ๋๋ค. ๋ฌธ์ ๋ ๊ฒ๋ณด๊ธฐ์ ๋ฌดํดํด ๋ณด์ด๋ ์ ์ฒ๋ฆฌ ๋จ๊ณ๋ฅผ ์ถ๊ฐํ๋ฉด ์ด๋ฌํ ํ๊ณ๋ฅผ ๋ฒ์ด๋ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ค๋ ๊ฒ์
๋๋ค.
Py 3.8์ mp.shared_memory
๋ ๋ง์ ๋นํ
์/nparray ๊ฐ์ฒด๋ฅผ ๊ณต์ ํ๋ ๋ฐ ์ถฉ๋ถํ ํด๊ฒฐ ๋ฐฉ๋ฒ์ ์ ๊ณตํ ์ ์์ต๋๋ค(์: ๊ณต์ ๋ชฉ๋ก: https://docs.python.org/3.8/library/multiprocessing.shared_memory.html). #multiprocessing.shared_memory.ShareableList. :)
๋ฉด์ฑ ์กฐํญ : ๋๋ ์ค์ ๋ก ๋ฌธ์๋ฅผ ์์ธํ ์ฝ์ง ์์์ต๋๋ค.
์ฌ๊ธฐ์ ์ฐ๋ฆฌ๊ฐ ํ ์ ์๋ ์กฐ์น๊ฐ ์์ต๋๊น? ์ผ๋ถ ์ฌ์ฉ ์ฌ๋ก๋ฅผ ํด๋น ์ฌ๋ก๋ก ์ฎ๊ธฐ๋ ๊ฒ์ ๋ฌธ์ํํ๊ธฐ ์ํด ํ ์น๋น์ ์์ ์ง์๋๋ ์ถฉ๋ถํ ์ด๋ฏธ์ง ๋ณํ์ด ์์ต๋๊น?
์ฌ๊ธฐ์์ ์์ ์ ๋ช ํํ ํ๊ธฐ ์ํด: @1e100 ์ด ์ ์ํ ๊ฒ์ ๊ตฌํํ๋ ๊ฒ์ ์ฐ๋ฆฌ๊ฐ torchvision์ ๋ก๋๋งต์ ํฌํจํ๊ณ ์๋ ๊ฒ์ด์ง๋ง, ์ฐ๋ฆฌ ๋ชฉ๋ก์ ๋งจ ์์ ์์ง ์์ผ๋ฉฐ ์๋ง๋ ๋จผ์ ์ค์ฒฉ ํ ์ ์ง์์ด ํ์ํ ๊ฒ์ ๋๋ค.
์ฆ, ์ด๊ฒ์ ์ด ๋ฌธ์ ์ ๋ํ ์ผ๋ฐ์ ์ธ ์์ ์ด ์๋๋๋ค. ๋ค๋ฅธ ์ ๊ทผ ๋ฐฉ์(์: GPU์ ๋ณํ)์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ๋ก๋์์ ๋ค์ค ์ฒ๋ฆฌ์ ํ์์ฑ์ ์ฐํํ ๋ฟ์ ๋๋ค.
๋๊ตฐ๊ฐ๊ฐ ์ค์ฒฉ ํ ์๋ฅผ ์ธ๊ธํ๋ ๊ฒ์ ๋ณด์์ ๋ cc @cpuhrsch . (๊ทธ๋ฐ๋ฐ @cpuhrsch , ์ค์ฒฉ ํ ์์ ๋ํ ๋ชจ๋ ๋ ์ด๋ธ์ ๋ง๋ค๊ณ https://github.com/pytorch/pytorch/issues/24422 ์์ ์ถ๊ฐํ ์ ์์ต๋๊น?)
์ด ๋ฒ๊ทธ๊ฐ 1๋ ๋์ ํด๊ฒฐ๋์ง ์์ ์ด์ ๋ ๋ฌด์์ ๋๊น?
@IMLHF ์ด ๋ฌธ์ ์ค๋ช ์ ์ฒซ ๋ฒ์งธ ์ค ๋๋ ์์ ๋ ผ์๋ฅผ ์ฐธ์กฐํ์ญ์์ค. ์ด๊ฒ์ ์ค์ ๋ก ๋์ถ์ด ์๋๋ผ ์ฐ๋ฆฌ์ ์์ ๋ฒ์ด๋ ๋ถํํ ํ์ด์ฌ ๋์์ธ์ด๊ธฐ ๋๋ฌธ์ ๋๋ค. pytorch์ numpy๋ ๋ชจ๋ ํ ์ ๋ฐ ndarray์ ๋ํ ์ฌ์ฉ์ ์ง์ ์ง๋ ฌํ๋ฅผ ๊ตฌํํ์ฌ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ค๊ณ ํ์ต๋๋ค. ๊ทธ๋ฌ๋ ์ฐ๋ฆฌ๋ ์ผ๋ฐ์ ์ธ ๋ฐ์ดํฐ ๊ตฌ์กฐ๋ฅผ ์ค๋ช ํ ์ ์์ต๋๋ค. ์ด๊ฒ์ ์ฌ์ฉ์๊ฐ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋๋ก ๋ ๋ง์ ์ ํธ๋ฆฌํฐ๋ฅผ ๊ตฌํํ๊ณ ์๊ธฐ ๋๋ฌธ์ ์ด๋ ค ์์ต๋๋ค.
๊ฐ ๋ฐ๋ณต ๋์ torch.cuda.empty_cache()
๋ฅผ ์ถ๊ฐํ๋ฉด ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐ ๋์์ด ๋ฉ๋๋ค. ์ด๊ฒ์ ์ถ๊ฐํ ํ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋์ ๋์ง ์๊ณ ๋ณ๋ํฉ๋๋ค.
์๋ง๋ ์ฐ๋ฆฌ๋ ๊ฒฝ๊ณ ๋ฅผ ์ถ๊ฐํด์ผ ํฉ๋๋ค.
@VitalyFedyunin ์ด๊ฒ์ ๋ณผ ๋์ญํญ์ด ์์ต๋๊น? ์ต์ํ ์ด๊ฒ์ด https://github.com/pytorch/pytorch/issues/17499์ ๋์ผํ ๋ฌธ์ ์ธ์ง ์์๋ผ ์ ์์ต๋๊น?
๋ด ํ๋ก์ ํธ์์ ํ ์๋ฅผ ์ฌ์ฉํ๋ ๋์ ndarray๋ฅผ ์ฌ์ฉํ์ฌ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ค๊ณ ์๊ฐํฉ๋๋ค.
๋ด ์ด์ ์ฝ๋๋
def df2var(x):
return (torch.LongTensor(token2id(x['Query'], max_char = max_length_char)),
torch.tensor(coll2id[x['Agg_Coll']], dtype = torch.long))
class Making_Dataset(Dataset):
def __init__(self, input_dataframe):
self.dataset = input_dataframe.apply(lambda x : df2var(x), axis = 1)
def __len__(self):
return len(self.dataset)
def __getitem__(self, data_index):
return self.dataset[data_index]
๊ทธ๋ฆฌ๊ณ ๋ค์๊ณผ ๊ฐ์ด ์ฝ๋๋ฅผ ์์ ํ์ต๋๋ค.
class Making_Dataset(Dataset):
def __init__(self, input_dataframe):
self.text = np.array([token2id(q, max_char = max_length_char) for q in input_dataframe.Query])
self.labels = np.array([coll2id[coll] for coll in input_dataframe.Agg_Coll])
def __len__(self):
return len(self.text)
def __getitem__(self, data_index):
return self.text[data_index], self.labels[data_index]
์ฝ๋๋ฅผ ์์ ํ ํ ๊ฐ ์ํฌํฌ์ ๋ฉ๋ชจ๋ฆฌ ์ฆ๊ฐ ๋ฌธ์ ๊ฐ ๋ด ํ๋ก์ ํธ์์ ์ฌ๋ผ์ก์ต๋๋ค.
์ด ๋ฌธ์ ์ ์์ธ์ด ๋ฌด์์ธ์ง ์ ํํ ๋ชจ๋ฅด๊ธฐ ๋๋ฌธ์ ์ด์ ๋ํ ๋ชจ๋ ์๊ฒฌ์ ํ์ํฉ๋๋ค!
Ubuntu 18.04์ CUDA 10์ด ํฌํจ๋ Torch 1.3.0๊ณผ ์ ์ฌํ ๋ฌธ์ ๊ฐ ์์ต๋๋ค. ์ด๊ฒ์ 64GB RAM์ด ์๋ AWS ์์คํ ์์๋ ๋ฌธ์ ๊ฐ ๋์ง ์์์ง๋ง 128GB RAM ๋ฐ 128GB ์ค์์ด ์๋ ๋ก์ปฌ ์์คํ ์์๋ 150 ์ํฌํฌ๋ ํต๊ณผํ ์ ์์ต๋๋ค. ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋ช GB(์์)์์ 128GB๋ก ๊ณ์ ์กฐ๊ธ์ฉ ๋์ด๋ฉ๋๋ค + GB.
๋ด ๋ฌธ์ ๋ ๊ตํํ ์์ ๋ฒ๊ทธ์์ต๋๋ค. ๊ต์ก ํต๊ณ๋ฅผ ๊ธฐ๋กํ๋ ๋์ ์์ ๊ฐ๊ณผ ํจ๊ป ๊ธฐ์ธ๊ธฐ ์ ๋ณด๋ฅผ ์ ์ฅํ๊ณ ์์์ต๋๋ค. ์ด ์ ๋ณด๋ ๋ถํ์ํ๊ณ ๊ฐ ์ํฌํฌ์ ๋ฉ๋ชจ๋ฆฌ ๊ณต๊ฐ์ ์ถ๊ฐ๋ฉ๋๋ค.
Py 3.8์ mp.shared_memory๋ ๋ง์ ๋นํ ์/nparray ๊ฐ์ฒด(์: ๊ณต์ ๋ชฉ๋ก)๋ฅผ ๊ณต์ ํ๋ ๋ฐ ์ถฉ๋ถํ ํด๊ฒฐ ๋ฐฉ๋ฒ์ ์ ๊ณตํ ์ ์์ต๋๋ค.
https://github.com/pytorch/pytorch/issues/13246#issuecomment -513480017
๋ง์ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋์์ต๋๋ค.
https://github.com/pytorch/pytorch/issues/13246#issuecomment -487042977
์๋
ํ์ธ์, ์ด ์ฃผ์ ์ ๋ํด ์กฐ๊ธ ๋ฆ์์ง๋ง ์ฌ์ ๊ณผ ๋์ผํ ๋ฌธ์ ์ ์ง๋ฉดํด ์์ต๋๋ค.
์ด ํดํน์ผ๋ก ์ฑ๊ณตํ ์ฌ๋์ด ์์ต๋๊น?
์ด๊ฒ์ ์ฌ์ ํ โโ์ ํจํ ๋ฌธ์ ์ ๋๋ค. ๋๊ตฐ๊ฐ๊ฐ ๋ฉ๋ชจ๋ฆฌ ๋์๋ฅผ ์ผ์ผํค์ง ์๊ณ DataLoaders์์ ์ฌ๋ฌ ์์ ์๋ฅผ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ชจ๋ฒ ์ฌ๋ก ๋ชฉ๋ก์ ์ ๊ณตํ ์ ์์ต๋๊น?
@marrrcin ๋ด ์๊ฐ์ ๊ฐ์ฅ ์ข์ ๋ฐฉ๋ฒ์ ํ ์๋ฅผ ๋น์ผ ๊ฒ์ผ๋ก ๊ฐ์ฃผํ๋ ๊ฒ์ด๋ฏ๋ก, ํนํ ํ ์์ ๊ทธ๋ผ๋์ธํธ ์ ๋ณด๊ฐ ์์ ๊ฐ๋ฅ์ฑ์ด ์๋ ๊ฒฝ์ฐ ์ฌ์ฉ ๋น๋์ ์ฃผ์ํด์ผ ํฉ๋๋ค.
์๋ฅผ ๋ค์ด torch
์์
์ ์ํํด์ผ ํ ๋๊น์ง ๋ชจ๋ ํญ๋ชฉ์ ๋ชฉ๋ก ๋๋ numpy.ndarray
๋ก ์ ์ฅํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
@AudreyBard๋ ๋ต๋ณ ๊ฐ์ฌํฉ๋๋ค. ๋ด ๋ฐ์ดํฐ ์ธํธ ์ฝ๋์๋ ๋ชจ๋ ๊ฒ์ด numpy/lists/strings/int๋ก ์ ์ฅ๋์ด ์์ผ๋ฉฐ ํ
์๋ฅผ ์ฌ์ฉํ๋ ์ ์ผํ ๋ถ๋ถ์ __getitem__
์ด๊ณ ๋์ค์ collate_fn
(ํจ๋ฉ ์ ์ฉ)์
๋๋ค. requires_grad
๊ฐ false๋ก ์ค์ ๋ ํ
์๋ฅผ ์์ฑํด์ผ ํฉ๋๊น? ๋ด ์ฝ๋๊ฐ num_workers>0์ผ๋ก ๋ค์ด๊ฐ๋ฉด ๋ฉ๋ชจ๋ฆฌ ๋์๊ฐ ์์๋ฉ๋๋ค.
๋ชจ๋ฐ์ผ์ด๋ผ ํ์์ด ์๋ง์์ ์ฃ์กํฉ๋๋ค.
@marrrcin ๋๋ ์ผ๋ฐ์ ์ผ๋ก __getitem__
์ tensor
๋ก ์
๋ ฅ ๋ฐ์ดํฐ(์ด๋ฏธ์ง ๋๋ ์ ํธ ๋ฑ)๋ง ์บ์คํธํฉ๋๋ค. ๋ด ๋ ์ด๋ธ ๋ฑ์ ์ผ๋ฐ์ ์ผ๋ก ๋ชฉ๋ก์ผ๋ก ๋ฐํ๋ฉ๋๋ค. ์ด๋ค ์ข
๋ฅ์ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๊ณ ์๋์ง ๋๋ ํน๋ณํ ์ข
๋ฅ์ ํจ๋ฉ์ ์ฌ์ฉํ๊ณ ์๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ง๋ง ์ผ๋ฐ์ ์ผ๋ก __getitem__
torchvision.transforms
#$ ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ๊ทธ๋งํ ๊ฐ์น๊ฐ ์๊ธฐ ๋๋ฌธ์ ์ฌ์ฉ์ ์ ์ collate_fn
๋ฅผ ๊ตฌํํ๋ ๊ฒฝ์ฐ๋ ๊ฑฐ์ ์์ต๋๋ค.
์๊ฐ: ๋๋ ๋ฉ๋ชจ๋ฆฌ ๋์๋ผ๊ณ ์๊ฐํ๋ ๊ฒ์ ๊ฒฝํํ๊ณ ์์๊ธฐ ๋๋ฌธ์ ์๋ ์ฌ๊ธฐ์ ๊ฒ์ํ์ต๋๋ค. ๋งค ์ํฌํฌ๋ง๋ค ๋ถํ์ํ ๋ฐ์ดํฐ์ ๋งค๋ฌ๋ฆฌ๊ณ ์์๊ณ , ์ ์ ์ฅ์์๋ ์ ๋ง ๋ฏธ๋ฌํ ๋ณ์ ๊ด๋ฆฌ์์ ๋ ๋์ ์ฆ์์ด ๋ํ๋ฌ์ต๋๋ค. ๋ฌด์จ ์ผ์ด ์ผ์ด๋๊ณ ์๋์ง ์ ํํ ํ์ ํ๋ ๋ฐ ์๊ฐ์ด ๊ฑธ๋ ธ์ต๋๋ค.
@AudreyBeard ์ ๊ฒฝ์ฐ๋ ์ด๋ฏธ์ง/torchvision๊ณผ ๊ด๋ จ์ด ์์ต๋๋ค. ๊ฐ๋ณ ๊ธธ์ด ํ
์คํธ์์ ์ถ์ถํ ํ ํฐ์ ํจ๋ฉ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ collate_fn
์ฌ์ฉํด์ผ ํฉ๋๋ค.
class PaddingCollateFn:
def __call__(self, batch):
sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
sequences = [x[0] for x in sorted_batch]
sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
attention_masks = [torch.tensor([1 for _ in x[0]]) for x in sorted_batch]
attention_masks_padded = torch.nn.utils.rnn.pad_sequence(
attention_masks, batch_first=True
)
lengths = torch.tensor([len(x) for x in sequences])
labels = torch.tensor([x[1] for x in sorted_batch])
return (sequences_padded, lengths, attention_masks_padded), labels
ํจ๋ฉ ํ ์๋ณธ ํ
์๋ฅผ ์ญ์ ํด์ผ ํ๋์(์ del
์ฌ์ฉ)? collate_fn
๊ฐ ์๋ฃ๋๋ฉด ํด๋น ํญ๋ชฉ์ ๋ํ ์ฐธ์กฐ๊ฐ ์์ผ๋ฏ๋ก ๋ฒ์๋ฅผ ๋ฒ์ด๋ ์ ๊ฑฐ๋ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ์ต๋๋ค.
๋๋ ์ด๊ฒ์ Pytorch ๋ฒ์ 1.3.1์์ ๋ง๋ฌ์ต๋๋ค.... ImageNet์ ํ๋ จํ ๋....๋๊ตฐ๊ฐ ์์ด๋์ด๊ฐ ์์ต๋๊น?
์ ๊ฒฝ์ฐ์๋ 24๊ฐ์ num_workers๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก Epoch 1์์๋ ์ฝ 110G ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ์ง๋ง ๋ ๋ฒ์งธ Epoch๋ฅผ ํ๋ จํ ๋ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ชจ๋ ์๋ชจ๋๊ณ ์์คํ
์ด ๋ฐ์ดํฐ๋ก๋๋ฅผ ์ฃฝ์ผ ๊ฒ์
๋๋ค..... ์์ธ์ง ๋ชฐ๋ผ....
๋์๊ฒ ๋ฌธ์ ๋ ์ด๋ฏธ numpy ๋ฐฐ์ด์ ๋ฐ์ดํฐ ๋ก๋ __getitem__
์์ ํ ์น ํ
์๋ก ๋ณํํ๊ณ ์๋ค๋ ๊ฒ์
๋๋ค.
Numpy ๋ฐฐ์ด์ ๋ชจ๋ธ๋ก ์ ์ก๋๊ธฐ ์ง์ ์ ํธ๋ ์ด๋ ๋ฃจํ์์ ํ ์น ํ ์๋ก๋ง ๋ณํ๋์ด์ผ ํฉ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ํ ์๊ฐ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ฒ์๋ฅผ ๋ฒ์ด๋๊ฒ ๋ง๋ญ๋๋ค.
watch -n .3 df -h
๋ช
๋ น์ ์คํํ์ฌ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ชจ๋ํฐ๋งํ ์ ์์ต๋๋ค.
๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ /dev/shm
ํ์ ํด๋นํฉ๋๋ค.
์ฌ์ฉ๋ ์์ ๊ฐ Epoch ํ์ ์ฆ๊ฐํ์ง ์์์ผ ํฉ๋๋ค.
๋๋ ๊ฐ์ ๋ฌธ์ ๊ฐ์๋ค
์ด ๋ฒ๊ทธ๋ pytorch 1.4.0์์ ํด๊ฒฐ๋์ง ์์์ต๋๋ค.
๋๋ ๊ฐ์ ๋ฌธ์ ๊ฐ์๋ค
์ ๋ ๊ฐ์ ๋ฌธ์ ์ ์ง๋ฉดํด ์์ต๋๋ค.
1) ๋ถํ์ํ ๋ณ์๋ฅผ ๋ชจ๋ ์ญ์
2) ๋ชฉ๋ก ๋์ numpy ๋ฐฐ์ด ์ฌ์ฉ
3) gc.collect() ์ฌ์ฉ
@annukkaa ๋ฐ ๊ธฐํ: ๋ฌธ์์ด ๋ชฉ๋ก์ ๋ง์ ๊ฐ์ฒด๋ก ์ ์ฅํ๊ธฐ ๋๋ฌธ์ np.array(list_of_paths)
๋ฅผ ์ฌ์ฉํ๋ ๊ฒ๋ง์ผ๋ก๋ ์ถฉ๋ถํ์ง ์์ต๋๋ค. np.array(list_of_paths).astype(np.string_)
๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐฐ์ด์ ์ ์ฌ๊ฐํ ๋ฐ์ดํธ ๋ฐฐ์ด๋ก ์บ์คํ
ํฉ๋๋ค(๊ทธ๋ฆฌ๊ณ ์ค์ ๋ก ๋ฌธ์์ด์ ์ฌ์ฉํ ๋ ๋ฐ์ดํธ์์ str๋ก ๋ณํํด์ผ ํจ). ๋์์ด ๋ ๊ฒ์
๋๋ค. ๋ํ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋์ ๊ฐ(์: 100GB)์ผ๋ก ์ค์ ํฉ๋๋ค.
์ด ์ค๋ ๋์์ ๋ช
์์ ์ผ๋ก ์ธ๊ธ๋ ๊ฒ์ ๋ณธ ์ ์ด ์์ผ๋ฏ๋ก ๋ด ์๋ฃจ์
์ ๊ณต์ ํ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ์ต๋๋ค.
์ ๊ฒฝ์ฐ์๋ ๋ชจ๋ ๋ฐ๋ณต์ ์ก์ธ์คํ์ฌ ๋น ๋ฅด๊ฒ CPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ณ ๊ฐ์ํจ ๋ฐ์ดํฐ ์ธํธ์ ์ฌ์ฉ์ ์ง์ ํด๋์ค ๊ฐ์ฒด์ ๋ฌธ์์ด ๋ชฉ๋ก์ด ์์์ต๋๋ค.
๊ณต์ ์ํ๋ฅผ ์ฒ๋ฆฌํ๋ ๋ค์ค ์ฒ๋ฆฌ ๊ด๋ฆฌ์ ๊ฐ์ฒด ๋ฅผ ์ฌ์ฉํ์ฌ ํด๋์ค์ ๋ชฉ๋ก์ ๋ํํ์ฌ ๋ฉ๋ชจ๋ฆฌ ๋์๋ฅผ ์ ๊ฑฐํ ์ ์์์ต๋๋ค.
์ต์ํ์ ์์ ์ ์ฐ๊ฒฐํ๋ฉด ์ฝ๋๋ ๋ค์๊ณผ ๊ฐ์ ๊ฒ์ ๋๋ค.
from torch.utils.data import Dataset, DataLoader
import torch
from multiprocessing import Manager
class DataIter(Dataset):
def __init__(self):
manager = Manager()
self.data = manager.list([x for x in range(24000000)])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = self.data[idx]
return torch.tensor(data)
train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
shuffle=True,
drop_last=True,
pin_memory=False,
num_workers=18)
for i, item in enumerate(train_loader):
if i % 1000 == 0:
print(i)
๊ฐ์ฒด๊ฐ ํผํด๋๊ธฐ ๋๋ฌธ์ ์ฝ๊ฐ์ ์ค๋ฒํค๋๊ฐ ์์ง๋ง ๋ฉ๋ชจ๋ฆฌ๊ฐ ํญ๋ฐํ๋ ๊ฒ๋ณด๋ค ์ข์ ๋์์ ๋๋ค.
์ด๊ฒ ๋ค ๊ณ ์ณ์ง๊น์????
๋ฌธ์ ๊ฐ ์์ง ์ด๋ ค์๋ ๊ฒ ๊ฐ์ต๋๋ค.
ndarrays๋ฅผ ์ฌ์ฉํด๋ ๋์์ด ๋์ง ์์ต๋๋ค. ์์
์๊ฐ 0์ธ ์ํ์์ CPU RAM์ ์ฝ 4๋ฐฐ ์ฌ๋ฆฝ๋๋ค.
๋ธ์ ์๋ํ์ง๋ง ์ ์๋ฏธํ ๊ฐ์ ์ ์์์ต๋๋ค.
์๋ ๋ชจ๋,
๋๋ ์ด๊ฒ์ ๋ํ ํด๊ฒฐ์ฑ ์ ์๋ํ๊ณ ์ด๊ฒ์ ์ ๋์ ์ธ ์๋ฆ๋ค์์ฒ๋ผ ์๋ํฉ๋๋ค.
๋๋ฅผ ์ํด ๋ก์ปฌ์ ์ ์ฅ๋ numpy ๋ฐฐ์ด๋ก imagenet ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํ๊ณ ์์ต๋๋ค.
๋ด ์ฌ์ฉ์ ์ ์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์์ฑํ์ต๋๋ค ---
`์์
ํ ์น
Torch.utils์์ ๋ฐ์ดํฐ ๊ฐ์ ธ์ค๊ธฐ
numpy๋ฅผ np๋ก ๊ฐ์ ธ์ค๊ธฐ
ํด๋์ค DataSetBuilder(data.Dataset):
"""TinyImagenet ๋ฐ์ดํฐ ์ธํธ."""
def __init__(self, rootpath, train=True, transform=None):
"""
Args:
rootpath: Path to the pytorch file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.path = rootpath
self.transform = transform
self.train = train
# Load input data
if self.train:
self.X_ = np.load(self.path +'x_train.npy')
else:
self.X_ = np.load(self.path +'x_test.npy')
# Load target data
if self.train:
self.y_ = np.load(self.path +'y_train.npy')
else:
self.y_ = np.load(self.path +'y_test.npy')
def __len__(self):
if self.train:
dataFile = self.path + 'x_train.npy'
else:
dataFile = self.path + 'x_test.npy'
data = np.load(dataFile)
return data.shape[0]
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
X = self.X_[idx, :, :, :]
y = self.y_[idx]
if self.transform is not None:
X = self.transform(X)
return X, torch.from_numpy(y).type(torch.LongTensor)`
__getitem__์ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๋ ๋์ ๊ฐ์ฒด๋ฅผ ๋น๋ํ ๋ ๋ก๋ํ๊ณ ์์ต๋๋ค. ์ฆ, ๋งค๋ฒ ๋์ผํ numpy ๋ฐฐ์ด์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋ํ์ง ์๊ณ ๊ฐ์ฒด ์์ฑ ์ ํ ๋ฒ์ ๋ก๋ํฉ๋๋ค.
๋์์ด ๋์๊ธฐ๋ฅผ ๋ฐ๋๋๋ค!
์ด๊ฒ์ด ํจ๊ณผ๊ฐ ์๋ค๋ฉด ๋๊ธ์ ๋จ๊ฒจ์ฃผ์ธ์... :-)
์๋
ํ์ธ์ @varinder-singh๋,
ํด๊ฒฐ์ฑ
์ ์ฐพ์ผ์
จ๋ค๋ ๋คํ์
๋๋ค. ์ด์ ์ @bfreskura ๊ฐ ์ ๊ณตํ numpy ์์ ์ ์ด๊ฒ์ด ์ด๋ป๊ฒ ๋ค๋ฅธ์ง ๋ชจ๋ฅด๊ฒ ์ต๋๋ค. ๊ทํ์ __getitem__
๋ ๋ํ numpy ๋ฐฐ์ด์์ ๋ฐ์ดํฐ๋ฅผ ์ฌ๋ผ์ด์คํฉ๋๋ค.
์ฝ๋๋ฅผ ์๋ชป ์ฝ๊ณ ์๋ ๊ฒ ๊ฐ์ต๋๋ค. ๋ฉ๋ชจ๋ฆฌ ์๋น์ ๋ค๋ฅธ ์ํฅ์ ๋ฏธ์น๋ ์ด์ ๋ฅผ ์ค๋ช
ํด ์ฃผ์๊ฒ ์ต๋๊น?
ํ์ฌ ํ๋ก์ ํธ์์ ์ด ๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ณ ์ด ์ค๋ ๋๋ฅผ ์ฝ์ ํ ๋ด ์๊ฐ์ ์ถ๊ฐํ๊ณ ๋ค์ ๋ค์ํ ์๋ฃจ์ ์ ์ ๊ณตํ๋ ๊ฒ์ด ์ ์ฉํ ์ ์๋ค๊ณ ์๊ฐํฉ๋๋ค.
๋จผ์ ์ฒซ ๋ฒ์งธ ๊ฒ๋ค:
1) ์ฌ๊ธฐ์์ ๊ด์ฐฐํ ์ ์๋ ๋ชจ๋ ๊ฒ์ ๊ณ ๋ คํ๋ฉด @mprostock ์ ์ง๋จ์ด ์ ํํฉ๋๋ค. ๋น์ ์ ์์
๋๋ถ์ ํผ์์ ๋
์ ํ๋๋ฐ ๋ง์ ์๊ฐ์ ์ ์ฝํ ์ ์์์ต๋๋ค.
2) ๋ฌผ๋ก @sumith ์ ์๋ต๋ ์ ํํ์ง๋ง @mprostock ์ ์ดํ ๊ฐ์ฒด ๋ฐฐ์ด ๊ฒ์๋ฌผ์์ ์ธ๊ธํ ์ด์ ๋ก ์ธํด ์ด ๊ฒฝ์ฐ์๋ ์ ์ฉ๋์ง ์์ต๋๋ค.
์ด๊ฒ์ pyTorch ๋ฌธ์ ๊ฐ ์๋๋๋ค. ์ด๊ฒ์ ํ์ด์ฌ ๋ฌธ์ ์ด๋ฏ๋ก ๊ฑฐ๊ธฐ์์ ํด๊ฒฐํด์ผ ํฉ๋๋ค. ๊ทธ๋ฌ๋ ๋ฌธ์ ๋ Python์ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ์ ํ์์ ์ธ ๋ถ๋ถ์ธ ์ฐธ์กฐ ์นด์ดํ ์ผ๋ก ์ธํด ๋ฐ์ํ๊ธฐ ๋๋ฌธ์ ์ด๋ ๊ณง ๋ฐ์ํ์ง ์์ ์ ์์ต๋๋ค. ์์์ ์ ์ํ ํด๊ฒฐ ๋ฐฉ๋ฒ ์ค ์ผ๋ถ๋ ํฅ๋ฏธ๋กญ์ง ๋ง ์ ๊ทธ๋ฌํ ๊ธธ์ด๋ก ์ด๋ํฉ๋๊น? ์์ ์ด ํ์ผ ์ด๋ฆ๊ณผ ๊ฐ์ ์ฌ๋ฌ ๊ฐ๋ณ ๊ธธ์ด ์ํ์ค์ ๊ณต๋์ผ๋ก ์ก์ธ์คํ๋ ๊ฒ์ด๋ผ๊ณ ๊ฐ์ ํ๋ฉด ์๋ก์ด ๊ฒ์ ๋ฐ๋ช ํ ํ์๊ฐ ์์ต๋๋ค. numpy๋ฅผ ์ฌ์ฉํ์ฌ ์ํ์ค๋ฅผ ์์ถํ๊ณ ๊ฐ์ ์กฐํ๋ฅผ ์ํํ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค. ๋ด ๋ง์ ์ดํดํ๋ ค๋ฉด ์ด ์ค๋ ๋์์ ๋ ผ์๋ ๋ฌธ์ ๋ฅผ ์์ ํ ํผํ๋ ์๋ ์ฝ๋๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
@mprostock ๋ฐ @smolendawid ๋ฌธ์์ด์ ๋ณธ์ง์ ์ผ๋ก ์ ์ ์ํ์ค์ด๋ฉฐ numpy์์ ์ฝ๊ฒ ์ฒ๋ฆฌํ ์ ์๋ ์ ํ์
๋๋ค. ์๋ ์์ ๋ ์ฌ๋ฌ ๋ฐ์ดํฐ ๋ก๋ ๊ฐ์ ๋ฌธ์์ด ๋ชฉ๋ก(์: ์ด๋ฏธ์ง์ ํ์ผ ์ด๋ฆ)์ ๊ณต์ ํ๋๋ก ์กฐ์ ๋์์ต๋๋ค.
@marrrcin ๋ชจ๋ฒ ์ฌ๋ก๋ฅผ ์์ฒญํ์
จ์ต๋๋ค. ์ด๊ฒ์ ๊ฐ๋ ฅํ๋ฉฐ ๊ฐ๋ณ ๊ธธ์ด ์ํ์ค์ ๋ชจ๋ ๋ชฉ๋ก์์ ์๋ํฉ๋๋ค. ํ์ฌ ํ๋ก์ ํธ์์ ๊ฐ ์ฐจ์์ ๊ธธ์ด๊ฐ ๊ฐ๋ณ์ ์ธ ๋ค์ฐจ์ ๋ฐ์ดํฐ์ ๋ํด ์ด๊ฒ์ ์ฝ๊ฐ ๋ ์ ๊ตํ ๋ณํ์ ์ฌ์ฉํฉ๋๋ค.
@SsnL ์ด๊ฒ์ ๋ฉ์ง Python 3.8 ๊ตฌ์ฑ์ ์ฌ์ฉํ์ง ์๊ณ /issues/ 20433 ์์ @zhiweifang๊ณผ ๋
ผ์ํ ๋ฌธ์ ๋ฅผ ์์์ ์ผ๋ก ํด๊ฒฐํฉ๋๋ค.
import numpy as np
import torch
from typing import Union
# --- UTILITY FUNCTIONS ---
def string_to_sequence(s: str, dtype=np.int32) -> np.ndarray:
return np.array([ord(c) for c in s], dtype=dtype)
def sequence_to_string(seq: np.ndarray) -> str:
return ''.join([chr(c) for c in seq])
def pack_sequences(seqs: Union[np.ndarray, list]) -> (np.ndarray, np.ndarray):
values = np.concatenate(seqs, axis=0)
offsets = np.cumsum([len(s) for s in seqs])
return values, offsets
def unpack_sequence(values: np.ndarray, offsets: np.ndarray, index: int) -> np.ndarray:
off1 = offsets[index]
if index > 0:
off0 = offsets[index - 1]
elif index == 0:
off0 = 0
else:
raise ValueError(index)
return values[off0:off1]
# --- OUR DATASET CODE STARTS HERE ---
class MyDataset(torch.utils.data.Dataset):
def __init__(self):
strings = [
'I like', # You can use np.int8 for ASCII strings.
'chocolate',
'ๆๅๆฌข', # If you use anything that is not standard ASCII,
'ๅทงๅ
ๅ', # need to use np.int16, or even np.int32.
]
# Convert each string to sequence of codepoints (integer),
# and then pack them into a numpy array.
seqs = [string_to_sequence(s) for s in strings]
self.strings_v, self.strings_o = pack_sequences(seqs)
def __len__(self): return 4
def __getitem__(self, i):
# Use indirect lookup to fetch the i-th sequence. This only uses integer numpy
# array lookups, which avoids that the objects are subsequently replicated by
# child processes.
seq = unpack_sequence(self.strings_v, self.strings_o, i)
string = sequence_to_string(seq)
# ACTION NEEDED: You probably do not want to return the string itself ;-).
return string
m = MyDataset()
for i in range(len(m)):
print(i, '=', m[i])
# Output
# -------
# 0 = I like
# 1 = chocolate
# 2 = ๆๅๆฌข
# 3 = ๅทงๅ
ๅ
๋๋ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์์๊ณ ๋ด 2์ผํธ๋ฅผ ๊ณต์ ํ๊ณ ์ถ์ต๋๋ค. ๋๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ฌธ์์ด์ด ๋ฌธ์ ๋ผ๊ณ @harpone ๊ณผ ๋ค๋ฅธ ์ฌ๋๋ค์ด ์ง์ ํ ์์ด๋์ด๋ฅผ ๋ฐ๋์ต๋๋ค. ๋ด Dataset ํด๋์ค์ 2๊ฐ์ ๋ฌธ์ ๊ฐ ์๋ ์ธ์๊ฐ ์์ต๋๋ค.
๋ฉ๋ชจ๋ฆฌ ๋์๋ฅผ ๋ง์ผ๋ ค๋ฉด 1๊ณผ 2๋ฅผ ๋ชจ๋ ์์ ํด์ผ ํ์ต๋๋ค. 1์ ๊ฒฝ์ฐ ๋ด ๋ฌธ์์ด์ ์ค์ ๋ก ์ฌ์ ์ numpy ๋ฒกํฐ์ ์ก์ธ์คํ๊ธฐ ์ํ ํด์์ด๋ฏ๋ก ๊ณ ์ ํฌ๊ธฐ ์ฌ์ ์ด ์์ผ๋ฏ๋ก ๋ชจ๋ ๋ฌธ์์ด์ ์ ์๋ก ๋ณํํ์ต๋๋ค.
2์ ๊ฒฝ์ฐ ์ ์ ํค๋ฅผ ์ฌ์ฉํ๋๋ก ์ฌ์ ์ ๋ณํํ์ง๋ง ๋ฉ๋ชจ๋ฆฌ ๋์๋ ์ฌ์ ํ ์ง์๋์์ต๋๋ค. ์ค์ ๋ก ์๋ํ ๊ฒ์ ์ฌ์ ์ Dataset ํด๋์ค์ ์ ํ ์ ๋ฌํ์ง ์๊ณ __getitem___์์ interger ํค๋ฅผ ๋ฐํํ๊ณ Pytorch ํ ์๋ก ์ฌ์ ์ธ๋ฑ์ฑ/์ด๋/๋ด ๊ธฐ์ฐจ ๋ฃจํ์์ GPU๋ก ์น๊ฒฉํ๋ ๊ฒ์ด์์ต๋๋ค.
๋ฐ์ดํฐ ๋ก๋ ํ๋ก์ธ์ค๊ฐ ๋งค ์ํฌํฌ(epoch)๋ง๋ค ์ค์ค๋ก๋ฅผ ๋ค์ ์ด๊ธฐํํ๊ณ ๋์ถ๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ชจ๋ ์ ๋ฆฌํ๋๋ก ํ๋ ๋ฐฉ๋ฒ์ด ์์ต๋๊น?
@Pozimek ๊ทธ๋ค์ ์ด๋ฏธ ๋ชจ๋ ์๋๋ฅผ ๋ค์ ์ด๊ธฐํํฉ๋๋ค.
๊ทธ๋ ๋ค๋ฉด ์ง๊ธ ๊ฐ์ฅ ์ข์ ๋ฐฉ๋ฒ์ ๋ฌด์์ ๋๊น?
@wangchust : @bashimao ๊ฐ ์ ์ํ ์๋ฃจ์ ์ ์ ๋นํ ํฐ(2,500๋ง ๊ฐ ์ด์์ ํ ์คํธ ์ํ์ค) ๋ฐ์ดํฐ ์ธํธ์์๋ ์๋ฆ๋ต๊ฒ ์๋ํ์ต๋๋ค.
@wangchust : @bashimao ๊ฐ ์ ์ํ ์๋ฃจ์ ์ ์ ๋นํ ํฐ(2,500๋ง ๊ฐ ์ด์์ ํ ์คํธ ์ํ์ค) ๋ฐ์ดํฐ ์ธํธ์์๋ ์๋ฆ๋ต๊ฒ ์๋ํ์ต๋๋ค.
์ ๋ ์. @bashimao ์ ์๋ฃจ์ ์ ๋งค์ฐ ์ ์๋ํฉ๋๋ค.
์๋ ํ์ธ์ ์ฌ๋ฌ๋ถ, ๋ค์ ์์ต๋๋ค. ๋ฐ์ดํฐ ๋ก๋ ์์ ์๊ฐ ์ฃผ ํ๋ก์ธ์ค์์ ๋ถ๊ธฐํ ๋ "OverflowError: 4GiB๋ณด๋ค ํฐ ๋ฐ์ดํธ์ด ๊ฐ์ฒด๋ฅผ ์ง๋ ฌํํ ์ ์์ต๋๋ค"๋ฅผ ์ถฉ์กฑํ๋ ์ฌ๋์ด ์์ต๋๊น?
์๋ ํ์ธ์ ์ฌ๋ฌ๋ถ, ๋ค์ ์์ต๋๋ค. ๋ฐ์ดํฐ ๋ก๋ ์์ ์๊ฐ ์ฃผ ํ๋ก์ธ์ค์์ ๋ถ๊ธฐํ ๋ "OverflowError: 4GiB๋ณด๋ค ํฐ ๋ฐ์ดํธ์ด ๊ฐ์ฒด๋ฅผ ์ง๋ ฌํํ ์ ์์ต๋๋ค"๋ฅผ ์ถฉ์กฑํ๋ ์ฌ๋์ด ์์ต๋๊น?
@wangchust ์ง๋ ฌํ ์ค์ด๋ผ๋ฉด ์๋ง๋ ๋ญ๊ฐ ์๋ชปํ๊ณ ์์ ๊ฒ์ ๋๋ค. ๊ฐ ํ๋ก์ธ์ค๋ 4๊ฐ ๋๋ ์๋ฌด๋ฆฌ ํฐ ๊ฐ์ฒด๊ฐ ๊ธฐ๊ฐ๋ฐ์ดํธ ๋จ์๋ผ๋ ์ญ์ง๋ ฌํํ๊ณ ์ง๋ ฌํ๋ ๊ฐ์ฒด๋ฅผ ์ฌ๊ตฌ์ฑํฉ๋๋ค. ๋ฐ๋ผ์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ณต์ ํ๊ณ ๋ณ๋ ฌ ํ๋ก์ธ์ค๊ฐ ๋ง์ผ๋ฉด ๊ฒฐ๊ตญ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ์กฑํด์ง๋๋ค. ์ด ์ค๋ ๋์์ ๋์ ๋ค๋ฅธ ์ฌ๋๋ค์ด ์ ์ํ ์กฐ์น์ ์์ ์ ๋ฉ๋ชจ๋ฆฌ ๋ณต์ ๋ฅผ ํผํ๋ ๊ฒ์ ๋๋ค. ๋ด ์ฒซ ๋ฒ์งธ ๋ฌธ์ฅ์์ ๋งํ๋ฏ์ด, ๋๋ ๋น์ ์ด ์๋ง๋ ๊ฝค ๊ธฐ๋ณธ์ ์ธ ์์ค์์ ๋ญ๊ฐ ์๋ชปํ๋ค๊ณ ๋ฏฟ์ต๋๋ค.
์ฌ์ฉ์ ์ง์ ํ ์ ์ง์ ๋ฌธ์์ด ๋ฐฐ์ด์ด https://gist.github.com/vadimkantorov/86c3a46bf25bed3ad45d043ae86fff57 ์ ๋์์ด ๋๋ ๊ฒ ๊ฐ์ต๋๋ค.
import torch
class TensorBackedImmutableStringArray:
def __init__(self, strings, encoding = 'utf-8'):
encoded = [torch.ByteTensor(torch.ByteStorage.from_buffer(s.encode(encoding))) for s in strings]
self.cumlen = torch.cat((torch.zeros(1, dtype = torch.int64), torch.as_tensor(list(map(len, encoded)), dtype = torch.int64).cumsum(dim = 0)))
self.data = torch.cat(encoded)
self.encoding = encoding
def __getitem__(self, i):
return bytes(self.data[self.cumlen[i] : self.cumlen[i + 1]]).decode(self.encoding)
def __len__(self):
return len(self.cumlen) - 1
def __list__(self):
return [self[i] for i in range(len(self))]
์๋ง๋ ์ด์ ๊ฐ์ sth๋ ํต์ฌ PyTorch์ ํฌํจํ ๊ฐ์น๊ฐ ์์ต๋๋ค.
๋๊ตฌ๋ ์ง ์ฌ์ ์ด ์๋ํ๊ณ ๋์ถ๋์ง ์๋๋ก ์ด์ด ์์ต๋๊น?
์์์ ์ด ๊ฒ์๋ฌผ ์ ๋ณด์์ง๋ง ํด๋น ์๊ฒฌ์์ ์ ์ํ๋ ๊ฒ์ฒ๋ผ ์ธ๋ถ์์ ์์
์ ์ํํ๋ ๋์ ์์
์ ๋ด๋ถ์ ์ผ๋ถ ์ ํ์ ํด์ ํ
์ด๋ธ์ ์ก์ธ์คํ๊ณ ์ถ์ต๋๋ค.
๋ค์ ์ค ํ๋๋ฅผ ๊ณ ๋ คํ๊ณ ์์ต๋๋ค.
๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ ๊ฐ์ฅ ์ ๋งํ๊ณ python ์ต์ ์ ๊ณ ์ ํ ๊ฒ์ฒ๋ผ ๋ณด์ ๋๋ค. dict๋ฅผ ์ฌ์ฉํ๋ ์ด์ ๊ฐ ๊ถ๊ธํฉ๋๋ค. ์ฌ๊ธฐ์ ์ผ๋ฐ์ ์ธ ํจํด์ ํญ๋ชฉ ๋ชฉ๋ก(์ผ๋ฐ์ ์ผ๋ก ๋ฌธ์์ด)์ ๊ฐ๊ณ ์์ธ์ ์์ฑํ๋ ๊ฒ์ ๋๋ค.
๋์๊ฒ ๊ทธ๊ฒ์ ์๋ dicts ๋ชฉ๋ก์ด์์ต๋๋ค (์์ ๋ฉํ ๋ฐ์ดํฐ ๋ชฉ๋ก, ๋ชจ๋ ์๋ dict์์ต๋๋ค)
์์์ด์. ์ผ๋ฐ์ ์ผ๋ก dicts๋ ๋ฉ๋ชจ๋ฆฌ ์ก์ธ์ค ํจํด์ด ์์ฐจ์ ์ด์ง ์๊ธฐ ๋๋ฌธ์ ๋ ์ด๋ ต๊ฒ ๋ง๋ญ๋๋ค. Fork-safe ๋ฐ์ดํฐ ๊ตฌ์กฐ ์ง์์ ์ถ๊ฐํ ์๊ฐ์ ๋๋ค.
๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ ๊ฐ์ฅ ์ ๋งํ๊ณ python ์ต์ ์ ๊ณ ์ ํ ๊ฒ์ฒ๋ผ ๋ณด์ ๋๋ค. dict๋ฅผ ์ฌ์ฉํ๋ ์ด์ ๊ฐ ๊ถ๊ธํฉ๋๋ค. ์ฌ๊ธฐ์ ์ผ๋ฐ์ ์ธ ํจํด์ ํญ๋ชฉ ๋ชฉ๋ก(์ผ๋ฐ์ ์ผ๋ก ๋ฌธ์์ด)์ ๊ฐ๊ณ ์์ธ์ ์์ฑํ๋ ๊ฒ์ ๋๋ค.
@VitalyFedyunin ํ ์ฃผ์
์ ๊ฐ์ฌํฉ๋๋ค. ๊ณต์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋จผ์ ์๋ํด ๋ณผ ์ ์์ต๋๋ค.
dict์ ์ด์ ๋ ๋ฐ๋ก ์ง๊ธ ๋ฐ์ดํฐ ์์ฑ ๋จ๊ณ์์ ์์ ์ํ๋ง ๊ธฐ๋ฅ์ ๋ํ ์์์ O(1) ์กฐํ์
๋๋ค. ๋ณด๋ค ๊ตฌ์ฒด์ ์ผ๋ก, dict๊ฐ user_id์ ๋ํด ์
๋ ฅ๋๊ณ ๊ฐ์ด ํด๋น ์ฌ์ฉ์์ ๊ด๋ จ๋ ๊ธ์ ์ ์ธ ์์ ๋ชฉ๋ก์ธ "ํธ๋ฆฌํ๋ ๋ง์ด๋"์
๋๋ค. ์๋ฅผ ๋ณด๋ ค๋ฉด ์ฌ๊ธฐ ๋ฅผ ์ฐธ์กฐํ์ญ์์ค.
@marrrcin ๋๋ ์ผ๋ฐ์ ์ผ๋ก
__getitem__
์tensor
๋ก ์ ๋ ฅ ๋ฐ์ดํฐ(์ด๋ฏธ์ง ๋๋ ์ ํธ ๋ฑ)๋ง ์บ์คํธํฉ๋๋ค. ๋ด ๋ ์ด๋ธ ๋ฑ์ ์ผ๋ฐ์ ์ผ๋ก ๋ชฉ๋ก์ผ๋ก ๋ฐํ๋ฉ๋๋ค. ์ด๋ค ์ข ๋ฅ์ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๊ณ ์๋์ง ๋๋ ํน๋ณํ ์ข ๋ฅ์ ํจ๋ฉ์ ์ฌ์ฉํ๊ณ ์๋์ง ์ ๋ชจ๋ฅด๊ฒ ์ง๋ง ์ผ๋ฐ์ ์ผ๋ก__getitem__
torchvision.transforms
#$ ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ๊ทธ๋งํ ๊ฐ์น๊ฐ ์๊ธฐ ๋๋ฌธ์ ์ฌ์ฉ์ ์ ์collate_fn
๋ฅผ ๊ตฌํํ๋ ๊ฒฝ์ฐ๋ ๊ฑฐ์ ์์ต๋๋ค.์๊ฐ: ๋๋ ๋ฉ๋ชจ๋ฆฌ ๋์๋ผ๊ณ ์๊ฐํ๋ ๊ฒ์ ๊ฒฝํํ๊ณ ์์๊ธฐ ๋๋ฌธ์ ์๋ ์ฌ๊ธฐ์ ๊ฒ์ํ์ต๋๋ค. ๋งค ์ํฌํฌ๋ง๋ค ๋ถํ์ํ ๋ฐ์ดํฐ์ ๋งค๋ฌ๋ฆฌ๊ณ ์์๊ณ , ์ ์ ์ฅ์์๋ ์ ๋ง ๋ฏธ๋ฌํ ๋ณ์ ๊ด๋ฆฌ์์ ๋ ๋์ ์ฆ์์ด ๋ํ๋ฌ์ต๋๋ค. ๋ฌด์จ ์ผ์ด ์ผ์ด๋๊ณ ์๋์ง ์ ํํ ํ์ ํ๋ ๋ฐ ์๊ฐ์ด ๊ฑธ๋ ธ์ต๋๋ค.
@AudreyBard ๊ฐ์ฌํฉ๋๋ค. ์ด๊ฒ์ ๋์์ด๋์๊ณ ๋ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ์ต๋๋ค.
๋ด๊ฐ ๊ถ๊ธํ ์ ์ (1) ์ ํ์ด ๋ฉ๋ชจ๋ฆฌ ์๋น์ ๋ง์ ์ํฅ์ ๋ฏธ์น๋ ์ด์ ์ (2) ์ ์ฒด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ํ๋ก์ธ์ค ์ * ๋ฐ์ดํฐ ์์ฑ ํฌ๊ธฐ๋ณด๋ค ํจ์ฌ ๋ง์ ๊ฒ์ฒ๋ผ ๋ณด์ด๋ ์ด์ ์ ๋๋ค.
@bfreskura ์ ์์์ self.data
์ ํฌ๊ธฐ๋ 24e7 ์ ์๋ก ๋๋ต 1.83GB์
๋๋ค. ์คํฌ๋ฆฝํธ๋ฅผ ๋น ๋ฅด๊ฒ ์คํํ ์ ์๋๋ก 24e5๋ก ๋ฎ์ถ๋ฉด ๋ฐ์ดํฐ ๊ฐ์ฒด์ ํฌ๊ธฐ๋ ๋๋ต 18.92MB์
๋๋ค.
Python ๋ชฉ๋ก์ ๊ฒฝ์ฐ shuffle=False๋ก ์ค์ ํ๋ฉด ํ๋ก์ธ์ค๊ฐ 298.17MB๋ฅผ ์๋นํ๋ ๊ฒ์ผ๋ก ์ธก์ ๋ฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ shuffle=True๋ก ์ค์ ํ๊ณ ํ๋ก์ธ์ค๊ฐ 1.44GB๋ฅผ ์๋นํ๋ ๊ฒ์ผ๋ก ์ธก์ ํฉ๋๋ค.
๋ฐ๋ผ์ 18๋ช ์ด ๋๋ ์์ ์ + 1๊ฐ์ ์ฃผ์ ์์ ํ๋ก์ธ์ค, ๋ชจ๋ ๋ฐ์ดํฐ๊ฐ ๋ชจ๋ ํ๋ก์ธ์ค์ ๋ณต์ฌ๋๋๋ผ๋ ์ต๋ 359.48MB์ ์ถ๊ฐ RAM๋ง ์์ด์ผ ํฉ๋๋ค. shuffle=True์ผ ๋ ๊ทธ ์์ ๊ฑฐ์ 4๋ฐฐ๊ฐ ๋๋ ๊ฒฝ์ฐ๋ ์ด๋ป๊ฒ ๋ฉ๋๊น? ์์ฐจ ๋ ์์ ๋ฉ๋ชจ๋ฆฌ ์ก์ธ์ค ๋ฐ ๊ฒฐ๊ณผ์ ์ธ ํ์ด์ง ์ค๋ฅ์ ๊ด๋ จ์ด ์์ด์ผ ํ๋ค๊ณ ์๊ฐํ์ง๋ง ์ฌ๊ธฐ์์ ๋ฌด์จ ์ผ์ด ์ผ์ด๋๊ณ ์๋์ง ๋ ์ ํํ๊ฒ ์ค๋ช ํ ์ ์๋ ์ฌ๋์ด ์๋์ง ๊ถ๊ธํฉ๋๋ค.
@bfreskura ์คํฌ๋ฆฝํธ์ ๋ํ ๋ด ์์ ์ฌํญ(CLI ์คํ + ๋ฉ๋ชจ๋ฆฌ ์๋น ๋ณด๊ณ )์ ์ฐธ์กฐํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํ์ญ์์ค.
https://gist.github.com/Erotemic/3f017de31529dc64c1a54948f37da1d5
๋๋ค ์ก์ธ์ค๋ ํ์ด์ฌ์ด ๊ฐ์ฒด ์นด์ดํฐ๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ๋ค์ ์ฐ๋๋ก ํ์ฌ ๋ฉ๋ชจ๋ฆฌ ํ๋ ์์ ๊ธฐ๋ก ์ ๋ณต์ฌ๋ฅผ ์ ๋ฐํฉ๋๋ค. ์์ฐจ ์ก์ธ์ค๋ ๋ณ๊ฒฝ๋์ง ์์ ์นด์ดํฐ๋ฅผ ์์ฑํ์ง ์์์ผ๋ก์จ ์ ์ฌ์ ์ผ๋ก ์ต์ ํ๋ ์ ์์ต๋๋ค(GC ์ฃผ๊ธฐ์ ๋ฐ๋ผ ๋ค๋ฆ). ๋ํ ์์ ์ ์ *(๋ชจ๋ ๊ฐ์ฒด ํฌ๊ธฐ + ๊ฐ์ฒด ์ * ํ์ด์ฌ ๊ฐ์ฒด ํฌ์ธํฐ+์นด์ดํฐ ํฌ๊ธฐ)์ธ ์ต๋ ์ฌ์ฉ๋์ผ๋ก ์ถ์ ํ๋ ๊ฒ์ด ํจ์ฌ ๋ ์์ ํฉ๋๋ค(์ง๊ธ๊น์ง๋). ์ฐ๋ฆฌ๋ ํ์ฌ ์ ์ฒด ๋ฉ๋ชจ๋ฆฌ ๋ณต์ฌ๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํ ์๋ฃจ์ ์ ์ฐ๊ตฌํ๊ณ ์์ง๋ง ์๋นํ ์ฌ์ค๊ณ๊ฐ ํ์ํ๊ณ ์๊ฐ์ด ๊ฑธ๋ฆด ๊ฒ์ ๋๋ค.
@VitalyFedyunin ์ค๋ช ๊ฐ์ฌํฉ๋๋ค๋ง ์์ง ์ ์ดํด๊ฐ ๋์ง ์๋ ๊ฒ ๊ฐ์ต๋๋ค :์ค๋ง์ผ:
๋ชฉ๋ก ๋์ numpy ๋ฐฐ์ด, ์๋ฅผ ๋ค์ด np.string_
์ ํ์ ์ ์ฌ๊ฐํ numpy ๋ฐ์ดํธ ๋ฐฐ์ด์ ์ฌ์ฉํ์ฌ ์์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ์ง๋ง ์ด์ webdataset(https:/ /github.com/tmbdev/webdataset/issues/24#issuecomment-709101119). ๋ถ๋ช
ํ shm์ด ๋ถ์กฑํ์ง ์์ง๋ง @tmbdev ๊ฐ webdataset ์ค๋ ๋์์ ์์ ์ง์ ํ๋ฏ์ด ๋ฌธ์ ๋ ๊ณต์ ๋ฉ๋ชจ๋ฆฌ ์ธ๊ทธ๋จผํธ์ _์ซ์_์ผ ์ ์์ต๋๋ค...
์ด ๋ฌธ์ ๋ฐ/๋๋ ๊ด๋ จ ์์ ํดํน์ ๋๋ฒ๊น
ํ๋ ๋ฐฉ๋ฒ์ ๋ํ ํ์ด ์์ต๋๊น? ๋๋ ipcs๋ฅผ ์๋ํ์ง๋ง ๊ทธ๊ฒ์ ๋์๊ฒ ์ ์ฉํ ๊ฒ์ ๋ณด์ฌ์ฃผ์ง ๋ชปํ์ต๋๋ค (๋ด ์๊ฐ์). lsof /dev/shm
๋ shm ๊ฐ์ฒด ๋ฐ ํฌ๊ธฐ์ ๋ํ ์ ๋ณด๋ฅผ ๋ณด์ฌ์ฃผ์ง๋ง ๊ทธ ์๋ฏธ๊ฐ ๋ฌด์์ธ์ง ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค...
์ ์๊ฒ proportional set size
(pss in psutil)์ ์ธก์ ํ๋ฉด ๋ฌธ์ ์ ํฌ๊ธฐ๋ฅผ ์ธก์ ํ๋ ๋ฐ ๋์์ด ๋์์ต๋๋ค. ์ฌ์ฉ์ ์ง์ StringArray ๋ฐ DictArray ํด๋์ค๋ก ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ์ต๋๋ค.
@wangchust : @bashimao ๊ฐ ์ ์ํ ์๋ฃจ์ ์ ์ ๋นํ ํฐ(2,500๋ง ๊ฐ ์ด์์ ํ ์คํธ ์ํ์ค) ๋ฐ์ดํฐ ์ธํธ์์๋ ์๋ฆ๋ต๊ฒ ์๋ํ์ต๋๋ค.
์ฃ์กํฉ๋๋ค. github ์ฌ์ฉ์ ๋ํด ๋๋ฝ๋ ๊ฒ์ด ์์ ์ ์์ง๋ง ์ด ์ค๋ ๋์์ @bashimao ์ ํด๊ฒฐ์ฑ ์ ์๊ณ ์ฃผ์๋ง ์์ต๋๋ค. ์๋ฌด๋ ๋์๊ฒ ๊ทธ๊ฒ์ ๊ฐ๋ฆฌํฌ ์ ์์ต๋๊น?
@ alexander-soare https://github.com/pytorch/pytorch/issues/13246#issuecomment -617140519
np.string_
๋ก ์บ์คํ
ํ๋ ๊ฒ์ด ํจ์ฌ ๊ฐ๋จํฉ๋๋ค( str
๋ฑ์ ์๋). ๊ฐ์ง๊ณ ์๋ค๊ณ
strings = ['hello', 'world']
๊ทธ๋ผ ํด
strings_byte = np.array(strings).astype(np.string_)
๊ทธ๋ฌ๋ฉด ๊ฒฐ๊ณผ๋ ๋จ์ผ ์ ์ฌ๊ฐํ ๋ฐ์ดํธ ๋ฐฐ์ด์ด ๋ฉ๋๋ค(dtype ์ฐธ๊ณ ).
array([b'hello', b'world'], dtype='|S5')
๊ทธ๋ฐ ๋ค์ str(strings_byte[0], encoding='utf-8')
์ ๊ฐ์ด ๋ฌธ์์ด์ ์ ํํ ๋ ๋ฌธ์์ด๋ก ๋ค์ ์ธ์ฝ๋ฉํด์ผ ํฉ๋๋ค.
์ด๊ฒ์ ์๋ํ์ง ์์ต๋๋ค.
strings_byte = np.array(strings).astype(str)
dtype์ ์ ์ํ์ญ์์ค.
array(['hello', 'world'], dtype='<U5')
๊ทธ๊ฒ์ ์ ์ฌ๊ฐํ ๋ฐ์ดํธ ๋ฐฐ์ด, ์ฆ ๋จ์ผ ๊ฐ์ฒด๊ฐ ์๋๋๋ค.
์ด ๋ฌธ์ ๊ฐ ์ง์๋๊ณ ๋์ ๋ด ๋๋ฃ๊ฐ ์ด ๋ฌธ์ ์ ๋ถ๋ชํ ํ์๋ฅผ ๊ณ ๋ คํ ๋ ์ด๊ฒ์ด ์์ธ์ธ์ง ์ฌ๋ถ๋ฅผ ๊ฒฐ์ ํ ์ ์๋ ๋ฐฉ๋ฒ์ด ์๋ค๋ฉด ๋์์ด ๋ ๊ฒ์ ๋๋ค. ์ด ์ค๋ ๋๋ฅผ ์ฒ ์ ํ ์ฝ์ ํ ๋ฌธ์ ๋ฅผ ์ํํ๊ธฐ ์ํ ์ข์ ์ ์์ด ์๋ ๊ฒ ๊ฐ์ต๋๋ค(https://github.com/pytorch/pytorch/issues/13246#issuecomment-436632186, https://github.com/pytorch/pytorch/issues). /13246#issuecomment-612396143), ์ผ๋ถ ํผ๋์ค๋ฌ์ด ๋์(https://github.com/pytorch/pytorch/issues/13246#issuecomment-708067670)๋ ์์ต๋๋ค.
shuffle=True
๋ฅผ ์ฌ์ฉํ๋ฉด https://github.com/pytorch/pytorch/issues/13246#issuecomment -708067670์ ์ค๋ช
๋ ๋๋ก ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋ ๋ง์์ง๋ ์ด์ ๋ ๋ฌด์์
๋๊น?๋์๊ฒ ๋ง๋ ์๋ฃจ์ - https://t.me/snakers4/2577
๋์๊ฒ ๋ง๋ ์๋ฃจ์ - https://t.me/snakers4/2577
์ด๊ฑฐ ์ข๋ค! https://gist.github.com/vadimkantorov/86c3a46bf25bed3ad45d043ae86fff57 ์์ ๋ด ๋ฐฉ๋ฒ์ ์ ์ผํ ์ฅ์ ์ DDP ํ๋ฆฌ๋ฏธํฐ๋ธ๋ฅผ ์ฌ์ฉํ์ฌ DDP ์์ ์ ๊ฐ์ ํ ์๋ก ๊ฐ๋ ์ฐฌ ๊ฐ์ฒด๋ฅผ ๊ณต์ ํ ์ ์๋ค๋ ๊ฒ์ ๋๋ค(์ฆ, ํ๋์ ์ค๋ ๋์์๋ง ๊ฑฐ๋ํ ๋ฐ์ดํฐ ์ธํธ ๊ฐ์ฒด๋ฅผ ์ฝ์ต๋๋ค. ๊ทธ๋ฐ ๋ค์ ํ ์๋ก ๊ฐ๋ ์ฐฌ ๋ฐ์ดํฐ ์ธํธ ๊ฐ์ฒด๋ฅผ ๋ค๋ฅธ DDP ์์์ ๋ถ์ฐ์ํต๋๋ค. ๊ฐ์ ๋ฐฉ์์ผ๋ก DDP ๋ง์คํฐ ์์ ์๋ DDP ์์์์ ํ ์๋ก ๊ฐ๋ ์ฐฌ ๋ฌธ์์ด ๋ฐฐ์ด์ ์์งํ ์ ์์ต๋๋ค.
์ด ๋ฒ๊ทธ์ ๋ ๋ค๋ฅธ ์ค์ ๋ฐ์: https://github.com/NVIDIA/NeMo/issues/1467
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
์ข ๋ ์กฐ์ฌํ ํ์ ๋์ถ์ด ๋ฐ์ํ๋ ์ ํํ ์๋๋ฆฌ์ค๋ฅผ ์ฐพ์์ต๋๋ค. ์๋ ์ฝ๋ ์์ ๋ฅผ ๊ณ ๋ คํ์ญ์์ค.
Python์ ํ์ค ์ ์ ๋ชฉ๋ก์ธ
self.data
๋ณ์๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฐ์ดํฐ ๋์๊ฐ ๋ฐ์ํฉ๋๋ค. ๊ทธ๋ฌ๋self.data_np
๋ณ์๋ฅผ ์ฌ์ฉํ๋ฉด ๋์ผํ ๋ฐ์ดํฐ๋ฅผ ๊ฐ์ง๊ณ ์์ง๋ง Numpy ๋ฐฐ์ด ํํ๋ก ๋์ด ์์ผ๋ฉด ๋์๊ฐ ๋ฐ์ํ์ง ์์ต๋๋ค.๋ ๋ค๋ฅธ ๊ด์ฐฐ์
DataLoader
์shuffle=False
์ธ ๊ฒฝ์ฐ ๋์ถ์ด ํจ์ฌ ๋ ์ฌ๊ฐํ๋ค๋ ๊ฒ์ ๋๋ค.