Pytorch: Information about torch.nn.CrossEntropyLoss

Created on 14 Apr 2017  ·  3 comments  ·  Source: pytorch/pytorch

About torch.nn.CrossEntropyLoss:
I am learning pytorch and, as an exercise, I am porting the anpr project
(https://github.com/matthewearl/deep-anpr,
http://matthewearl.github.io/2016/05/06/cnn-anpr/)
to the pytorch platform.

I have a problem. I am using nn.CrossEntropyLoss() as the loss function:
criterion = nn.CrossEntropyLoss()

The model's output.data is:
1.00000e-02 *
-2.5552 2.7582 2.5368 ... 5.6184 1.2288 -0.0076
-0.7033 1.3167 -1.0966 ... 4.7249 1.3217 1.8367
-0.7592 1.4777 1.8095 ... 0.8733 1.2417 1.1521
-0.1040 -0.7054 -3.4862 ... 4.7703 2.9595 1.4263
[torch.FloatTensor of size 4x253]

And target.data is:
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
[torch.DoubleTensor of size 4x253]

When I call:
loss = criterion(output, target)
I get an error. The message is:
TypeError: FloatClassNLLCriterion_updateOutput received an invalid combination of arguments - got (int, torch.FloatTensor, torch.DoubleTensor, torch.FloatTensor, bool, NoneType, torch.FloatTensor), but expected (int state, torch.FloatTensor input, torch.LongTensor target, torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight)

So it expected a torch.LongTensor but got a torch.DoubleTensor. However, if I convert the target to a LongTensor:
torch.LongTensor(numpy.array(targets.data.numpy(), numpy.long))
and then call loss = criterion(output, targets), the error becomes:
RuntimeError: multi-target not supported at /data/users/soumith/miniconda2/conda-bld/pytorch-0.1.10_1488752595704/work/torch/lib/THNN/generic/ClassNLLCriterion.c:20
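
For readers who hit the same two errors: the first one is about the dtype of the target, the second about its shape. Below is a minimal sketch of a check that reports both before the loss is called. It assumes a recent PyTorch release and the class-index form of the target; `check_class_index_target` is a hypothetical helper written for illustration, not part of PyTorch.

```python
import torch


def check_class_index_target(output, target):
    """Hypothetical helper: list reasons a class-index target would be rejected.

    Assumes `output` has shape (N, C) and the target is meant to hold one
    class index per sample, as in the thread. Recent PyTorch releases also
    accept floating-point class-probability targets of shape (N, C); that
    form is not covered here.
    """
    problems = []
    if target.dtype != torch.long:
        problems.append("target dtype is %s, expected torch.long" % target.dtype)
    if target.dim() != 1:
        problems.append("target has %d dims, expected 1" % target.dim())
    elif target.shape[0] != output.shape[0]:
        problems.append("target batch size differs from output batch size")
    return problems


output = torch.zeros(4, 253)
double_matrix = torch.zeros(4, 253, dtype=torch.float64)  # same shape/dtype as in the question
for problem in check_class_index_target(output, double_matrix):
    print(problem)
```

With a target shaped and typed like the one in the question, the helper reports both the dtype problem and the shape problem; a 1-D Long tensor of length 4 passes.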

My previous exercise was the mnist example from pytorch, which I modified slightly: batch_size is 4 and the loss function is:
loss = F.nll_loss(output, labels)
output.data:
-2.3220 -2.1229 -2.3395 -2.3391 -2.5270 -2.3269 -2.1055 -2.2321 -2.4943 -2.2996
-2.3653 -2.2034 -2.4437 -2.2708 -2.5114 -2.3286 -2.1921 -2.1771 -2.3343 -2.2533
-2.2809 -2.2119 -2.3872 -2.2190 -2.4610 -2.2946 -2.2053 -2.3192 -2.3674 -2.3100
-2.3715 -2.1455 -2.4199 -2.4177 -2.4565 -2.2812 -2.2467 -2.1144 -2.3321 -2.3009
[torch.FloatTensor of size 4x10]

labels.data:
8
6
0
1
[torch.LongTensor of size 4]

The label has to be a single element per input image. In the example above each target row has 253 numbers, while in 'mnist' it has only one number, so the shape of the output differs from the shape of the labels.

I reviewed the tensorflow manual for tf.nn.softmax_cross_entropy_with_logits:
'Logits and labels must have the same shape [batch_size, num_classes] and the same dtype (either float32 or float64).'

Can pytorch be used in this case? If so, how?
Many thanks

All 3 comments

http://pytorch.org/docs/nn.html#crossentropyloss
CrossEntropyLoss shape:
Input: (N, C), where C = number of classes
Target: (N), where each value satisfies 0 <= target[i] <= C-1

Is there any other choice?
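
The shape rule quoted above can be tried out directly. This is a minimal sketch, assuming a recent PyTorch release (plain tensors, no Variable wrapper); it reuses the batch size, class count, and labels from the mnist example in the question, while the random scores are made up for illustration. It also checks that CrossEntropyLoss on raw scores gives the same value as F.nll_loss on log-softmax outputs, which is what the mnist example uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

N, C = 4, 10                          # batch size and number of classes
output = torch.randn(N, C)            # raw scores, shape (N, C), float32 (made-up values)
labels = torch.tensor([8, 6, 0, 1])   # shape (N,), dtype int64, each in [0, C-1]

criterion = nn.CrossEntropyLoss()
loss = criterion(output, labels)      # scalar tensor

# Same quantity computed in two steps: log-softmax over classes, then NLL.
reference = F.nll_loss(F.log_softmax(output, dim=1), labels)

print(loss.item(), reference.item())
```

Both printed values should agree up to floating-point tolerance.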

I wouldn't say that the TF docs are the best place to learn about the PyTorch API 🙂 We're not trying to be compatible with TF, and our CrossEntropyLoss accepts a vector of class indices (this allows us to run much faster than if we used 1-hot vectors). It should be straightforward to convert between the two representations if you really need it.

We use GitHub issues only for bug reports. If you have questions, please ask them on the forums.
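
One possible way to do the conversion mentioned above, as a sketch under stated assumptions: a recent PyTorch release, and a target matrix in which every row contains exactly one 1. The matrix below is a hypothetical example with the same shape and dtype as the target in the question, not the asker's actual data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

N, C = 4, 253
# Hypothetical one-hot targets: a single 1 per row, at columns 3, 0, 7, 252.
true_classes = torch.tensor([3, 0, 7, 252])
one_hot = torch.zeros(N, C, dtype=torch.float64)
one_hot[torch.arange(N), true_classes] = 1.0

# One-hot matrix -> vector of class indices (position of the 1 in each row).
class_idx = one_hot.argmax(dim=1)     # shape (N,), dtype torch.int64

# Vector of class indices -> one-hot matrix, for the opposite direction.
restored = F.one_hot(class_idx, num_classes=C).to(one_hot.dtype)

output = torch.randn(N, C)            # made-up raw scores
loss = nn.CrossEntropyLoss()(output, class_idx)
print(class_idx.tolist(), loss.item())
```

If a row can contain more than one 1, argmax silently picks one of them, so this conversion only makes sense for genuinely one-hot rows.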

Thanks. Converting the one-hot class encoding matrix to an integer vector solved my CrossEntropyLoss problem!
