resnet50์์ ์์ ํ ๋ชจ๋ธ์ด ์๋๋ฐ ๋ง์ง๋ง avgpool
& fc
๋ง ์ ๊ฑฐํ๋ฉด ๋ฉ๋๋ค.
ํ๋ จ ์ค์ ์
๋ ฅ ํฌ๊ธฐ๋ฅผ ๊ณ์ ๋ณ๊ฒฝํ๋ฉด ์๋๊ฐ ๋๋ ค์ง๋ ๊ฒ์ ๋ฐ๊ฒฌํ์ต๋๋ค.
์ต์ ์ฝ๋:
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np
from torch.autograd import Variable
# ... remove avgpool & fc from resnet50 here
net = resnet50()
net.cuda()
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
cudnn.benchmark = True
for i in range(10):
h = np.random.randint(400,600)
w = np.random.randint(400,600)
# or fix h = w = 600
x = Variable(torch.randn(1,3,h,w)).cuda()
t1 = time.time()
y = net(x)
t2 = time.time()
print(t2-t1)
3.14512705803
0.11568403244
0.0255229473114
0.0228650569916
0.0235478878021
0.0225219726562
0.0436158180237
0.0222969055176
0.0223350524902
0.0227248668671
3.12573313713
0.670918941498
2.32590889931
2.3486700058
2.31507301331
0.593285083771
0.68169093132
2.34181690216
0.597991943359
1.74615192413
๋๋ ๋ํ CPU๋ง์ผ๋ก ํ๋ จํ๋๋ฐ ๋ ๋ค ์ ์์ ์ผ๋ก ์๋ํฉ๋๋ค. ๊ทธ๋์ ๋๋ ๊ทธ ์ด์ ๊ฐ CUDA ์ค๋ฒ ํค๋์ ๊ด๋ จ์ด ์๋ค๊ณ ์๊ฐํฉ๋๋ค. ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์์ด๋์ด๊ฐ ์์ต๋๊น?
์ฝ๋์ ์๋ฌด ๊ณณ์๋ cudnn.benchmark=True
๋ฅผ ์ค์ ํฉ๋๊น? ๊ทธ๊ฒ์ด ์๋ง๋ ๋ฒ์ธ์ผ ๊ฒ์
๋๋ค.
@fmassa ๊ฐ ์ฌ๊ธฐ์ ๋งํ๋ฏ์ด: https://discuss.pytorch.org/t/pytorch-performance/3079/7?u=smth
๋ฒค์น๋งํฌ ๋ชจ๋์์ ๊ฐ ์ ๋ ฅ ํฌ๊ธฐ์ ๋ํด cudnn์ ํน์ ๊ฒฝ์ฐ์ ๋ํด ๊ฐ์ฅ ๋น ๋ฅธ ์๊ณ ๋ฆฌ์ฆ์ ์ถ๋ก ํ๊ธฐ ์ํด ๋ง์ ๊ณ์ฐ์ ์ํํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ์บ์ํฉ๋๋ค. ์ด๊ฒ์ ์ฝ๊ฐ์ ์ค๋ฒํค๋๋ฅผ ๊ฐ์ ธ์ค๊ณ , ์ ๋ ฅ ์น์๊ฐ ํญ์ ๋ณ๊ฒฝ๋๋ ๊ฒฝ์ฐ ๋ฒค์น๋งํฌ๋ฅผ ์ฌ์ฉํ๋ฉด ์ด ์ค๋ฒํค๋๋ก ์ธํด ์ค์ ๋ก ์๋๊ฐ ๋๋ ค์ง๋๋ค.
์์ํ. ๋๋ ์ค์ ์ฃผ์ ์ฒ๋ฆฌํ๊ณ ๋ ๋ค ์ด์ ์ ์์ ์ผ๋ก ์๋ํฉ๋๋ค.
๊ฐ์ฌ ํด์.
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
@fmassa ๊ฐ ์ฌ๊ธฐ์ ๋งํ๋ฏ์ด: https://discuss.pytorch.org/t/pytorch-performance/3079/7?u=smth