PyTorch: Training with variable input size is slow

Created on 17 Sep 2017  ·  3 comments  ·  Source: pytorch/pytorch

I have a model modified from resnet50: I only removed the last avgpool & fc.
I found that if the input size keeps changing during training, training slows down.

Minimal code:

import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np
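from torch.autograd import Variable
from torchvision.models import resnet50  # import missing from the original post; torchvision is assumed

# ... remove avgpool & fc from resnet50 here
net = resnet50()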
net.cuda()
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
cudnn.benchmark = True

for i in range(10):
    h = np.random.randint(400,600)
    w = np.random.randint(400,600)
    # or fix h = w = 600
    x = Variable(torch.randn(1,3,h,w)).cuda()

    t1 = time.time()
    y = net(x)
    t2 = time.time()
    print(t2-t1)
  1. With the input size fixed to [600,600], on a system with 8 Nvidia P40 GPUs, I get:
3.14512705803
0.11568403244
0.0255229473114
0.0228650569916
0.0235478878021
0.0225219726562
0.0436158180237
0.0222969055176
0.0223350524902
0.0227248668671
  2. With the input size changing randomly within [400,600], I get:
3.12573313713
0.670918941498
2.32590889931
2.3486700058
2.31507301331
0.593285083771
0.68169093132
2.34181690216
0.597991943359
1.74615192413

๋‚˜๋Š” ๋˜ํ•œ CPU๋งŒ์œผ๋กœ ํ›ˆ๋ จํ–ˆ๋Š”๋ฐ ๋‘˜ ๋‹ค ์ •์ƒ์ ์œผ๋กœ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ž˜์„œ ๋‚˜๋Š” ๊ทธ ์ด์œ ๊ฐ€ CUDA ์˜ค๋ฒ„ ํ—ค๋“œ์™€ ๊ด€๋ จ์ด ์žˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค. ์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•  ์•„์ด๋””์–ด๊ฐ€ ์žˆ์Šต๋‹ˆ๊นŒ?

Most helpful comment

As @fmassa said here: https://discuss.pytorch.org/t/pytorch-performance/3079/7?u=smth

In benchmark mode, for each input size, cuDNN runs a number of computations to find the fastest algorithm for that specific case and caches the result. This brings some overhead, and if your input dimensions change all the time, benchmark mode will actually slow things down because of that overhead.
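To make the caching argument concrete, here is a toy model in plain Python. It is only an illustration, not cuDNN's actual implementation: `AutotuneCache` and `count_searches` are hypothetical names, and the "search" is just a counter standing in for the expensive algorithm trials.

```python
class AutotuneCache:
    """Toy stand-in for a per-shape algorithm cache (not cuDNN's real code)."""

    def __init__(self):
        self.best_algo = {}  # input shape -> name of the chosen algorithm
        self.searches = 0    # number of expensive searches performed so far

    def lookup(self, shape):
        if shape not in self.best_algo:
            # Cache miss: a real autotuner would time several algorithms here.
            self.searches += 1
            self.best_algo[shape] = "algo_for_%dx%d" % shape
        return self.best_algo[shape]


def count_searches(shapes):
    """Feed a sequence of input shapes through a fresh cache and count misses."""
    cache = AutotuneCache()
    for shape in shapes:
        cache.lookup(shape)
    return cache.searches


fixed = [(600, 600)] * 10                                   # same shape every step
varying = [(400 + 7 * i, 600 - 11 * i) for i in range(10)]  # new shape every step

print(count_searches(fixed))    # 1
print(count_searches(varying))  # 10
```

With a fixed shape the search cost is paid once and then amortized; with a new shape on every iteration it is paid every time. In a real network the cost is likely higher still, since each convolution layer would typically need its own search for each new input shape.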

All 3 comments

Do you set cudnn.benchmark=True anywhere in your code? That is probably the culprit.

Cool. I commented out that line, and both cases work fine now.
Thanks.
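For anyone re-running the comparison, a minimal sketch of the fix follows. It leaves `cudnn.benchmark` off (its default) and synchronizes the GPU before reading the clock, because CUDA kernels are launched asynchronously. The `timed` helper and the small `Conv2d` stand-in are assumptions for illustration, not part of the original script, and the GPU part only runs if PyTorch and a CUDA device are available.

```python
import time

try:
    import torch
except ImportError:  # keeps the sketch importable on machines without PyTorch
    torch = None


def timed(fn, sync=None):
    """Run fn() and return (result, seconds).

    `sync` is an optional callable (e.g. torch.cuda.synchronize) invoked
    before each clock read so that pending GPU work is included.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn()
    if sync is not None:
        sync()
    return result, time.perf_counter() - start


if torch is not None and torch.cuda.is_available():
    # Autotuning stays off because the input shape changes every iteration.
    torch.backends.cudnn.benchmark = False
    net = torch.nn.Conv2d(3, 8, kernel_size=3).cuda()  # stand-in for the real model
    for h, w in [(400, 600), (512, 448), (600, 600)]:
        x = torch.randn(1, 3, h, w, device="cuda")
        with torch.no_grad():
            _, seconds = timed(lambda: net(x), sync=torch.cuda.synchronize)
        print((h, w), seconds)
```

If the input size is fixed for the whole run, turning `cudnn.benchmark` back on is still reasonable, since the one-time search is then amortized over all iterations.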
