Torch.nn.CrossEntropyLoss ์ ๋ณด,
์ ๋ pytorch๋ฅผ ๋ฐฐ์ฐ๊ณ ์์ผ๋ฉฐ anpr ํ๋ก์ ํธ๋ฅผ ์งํ ์ค์
๋๋ค.
(https://github.com/matthewearl/deep-anpr,
http://matthewearl.github.io/2016/05/06/cnn-anpr/)
์ฐ์ต์ผ๋ก pytorch ํ๋ซํผ์ ์ด์ํ์ญ์์ค.
๋ฌธ์ ๊ฐ ์์ต๋๋ค. nn.CrossEntropyLoss()๋ฅผ ์์ค ํจ์๋ก ์ฌ์ฉํ๊ณ ์์ต๋๋ค.
๊ธฐ์ค=nn.CrossEntropyLoss()
๋ชจ๋ธ์ output.data๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
1.00000e-02 *
-2.5552 2.7582 2.5368 ... 5.6184 1.2288 -0.0076
-0.7033 1.3167 -1.0966 ... 4.7249 1.3217 1.8367
-0.7592 1.4777 1.8095 ... 0.8733 1.2417 1.1521
-0.1040 -0.7054 -3.4862 ... 4.7703 2.9595 1.4263
[torch.FloatTensor ํฌ๊ธฐ 4x253]
๊ทธ๋ฆฌ๊ณ target.data๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
[torch.DoubleTensor ํฌ๊ธฐ 4x253]
๋ด๊ฐ ์ ํํ ๋:
์์ค=๊ธฐ์ค(์ถ๋ ฅ, ๋ชฉํ)
์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค. ์ ๋ณด๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
TypeError: FloatClassNLLCriterion_updateOutput์ด ์๋ชป๋ ์ธ์ ์กฐํฉ์ ์์ ํ์ต๋๋ค. (int, torch.FloatTensor, torch.DoubleTensor, torch.FloatTensor, bool, NoneType, torch.FloatTensor)๋ฅผ ์ป์์ง๋ง ์์๋จ(int state, torch.FloatTensor ์
๋ ฅ, torch.LongTensor ๋์ , torch.FloatTensor ์ถ๋ ฅ, bool sizeAverage, [torch.FloatTensor weights ๋๋ None], torch.FloatTensor total_weight)
'์์๋ ํ ์น.LongTensor','torch.DoubleTensor๊ฐ ์์', ํ์ง๋ง ๋์์ LongTensor๋ก ๋ณํํ๋ ๊ฒฝ์ฐ:
ํ ์น.LongTensor(numpy.array(targets.data.numpy(),numpy.long))
loss=criterion(output,targets)์ ํธ์ถํ๋ฉด ์ค๋ฅ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
๋ฐํ์ ์ค๋ฅ: /data/users/soumith/miniconda2/conda-bld/pytorch-0.1.10_1488752595704/work/torch/lib/THNN/generic/ClassNLLCriterion.c:20์์ ๋ค์ค ๋์์ด ์ง์๋์ง ์์ต๋๋ค.
๋ด ๋ง์ง๋ง ์ด๋์ mnist, pytorch์ ์์
๋๋ค. ์ฝ๊ฐ ์์ ํ์ต๋๋ค. batch_size๋ 4์ด๊ณ ์์ค ํจ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
loss = F.nll_loss(์ถ๋ ฅ, ๋ ์ด๋ธ)
output.data:
-2.3220 -2.1229 -2.3395 -2.3391 -2.5270 -2.3269 -2.1055 -2.2321 -2.4943 -2.2996
-2.3653 -2.2034 -2.4437 -2.2708 -2.5114 -2.3286 -2.1921 -2.1771 -2.3343 -2.2533
-2.2809 -2.2119 -2.3872 -2.2190 -2.4610 -2.2946 -2.2053 -2.3192 -2.3674 -2.3100
-2.3715 -2.1455 -2.4199 -2.4177 -2.4565 -2.2812 -2.2467 -2.1144 -2.3321 -2.3009
[torch.FloatTensor ํฌ๊ธฐ 4x10]
๋ ์ด๋ธ.๋ฐ์ดํฐ:
8
6
0
1
[torch.LongTensor ํฌ๊ธฐ 4]
๋ ์ด๋ธ์ ์ ๋ ฅ ์ด๋ฏธ์ง์ ๊ฒฝ์ฐ ๋จ์ผ ์์์ฌ์ผ ํ๋ฉฐ ์์ ์์์๋ 253๊ฐ์ ์ซ์๊ฐ ์๊ณ 'mnist'์์๋ ํ๋์ ์ซ์๋ง ์์ผ๋ฏ๋ก ์ถ๋ ฅ์ ๋ชจ์์ด ๋ ์ด๋ธ๊ณผ ๋ค๋ฆ ๋๋ค.
๋๋ tensorflow ๋งค๋ด์ผ, tf.nn.softmax_cross_entropy_with_logits๋ฅผ ๊ฒํ ํฉ๋๋ค.
'๋ก์งํธ์ ๋ ์ด๋ธ์ ๋์ผํ ๋ชจ์[batch_size, num_classes] ๋ฐ ๋์ผํ dtype(float32 ๋๋ float64)์ ๊ฐ์ ธ์ผ ํฉ๋๋ค.'
์ด ๊ฒฝ์ฐ pytorch๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๊น? ์๋๋ฉด ์ด๋ป๊ฒ ํ ์ ์์ต๋๊น?
๋ง์ ๊ฐ์ฌ
http://pytorch.org/docs/nn.html#crossentropyloss
CrossEntropyLoss ๋ชจ์:
์
๋ ฅ: (N,C) ์ฌ๊ธฐ์ C = ํด๋์ค ์
๋์: (N) ์ฌ๊ธฐ์ ๊ฐ ๊ฐ์ 0 <= target[i] <= C-1
๋ค๋ฅธ ์ ํ์ด ์์ต๋๊น?
TF ๋ฌธ์๊ฐ PyTorch API์ ๋ํด ๋ฐฐ์ธ ์ ์๋ ๊ฐ์ฅ ์ข์ ์ฅ์๋ผ๊ณ ๋งํ์ง๋ ์๊ฒ ์ต๋๋ค. ๐ ์ฐ๋ฆฌ๋ TF์ ํธํ๋๋๋ก ๋ ธ๋ ฅํ์ง ์์ผ๋ฉฐ, CrossEntropyLoss๋ ํด๋์ค ์ธ๋ฑ์ค ๋ฒกํฐ๋ฅผ ํ์ฉํฉ๋๋ค(์ด๋ ๊ฒ ํ๋ฉด ์ฌ์ฉํ์ ๋๋ณด๋ค ํจ์ฌ ๋น ๋ฅด๊ฒ ์คํํ ์ ์์ต๋๋ค. 1-ํซ ๋ฒกํฐ). ์ ๋ง ํ์ํ ๊ฒฝ์ฐ ๋ ํํ ์ฌ์ด๋ฅผ ๋ณํํ๋ ๊ฒ์ด ๊ฐ๋จํด์ผ ํฉ๋๋ค.
๋ฒ๊ทธ ๋ณด๊ณ ์์๋ง GitHub ๋ฌธ์ ๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ์ง๋ฌธ์ด ์๋ ๊ฒฝ์ฐ ํฌ๋ผ ์์ ์ง๋ฌธํ์ญ์์ค.
๊ฐ์ฌํฉ๋๋ค. ์-ํซ ํด๋์ค ์ธ์ฝ๋ฉ ํ๋ ฌ์ ์ ์ ๋ฒกํฐ๋ก ๋ณํํ๋ฉด CrossEntropyLoss ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋์์ต๋๋ค!
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
TF ๋ฌธ์๊ฐ PyTorch API์ ๋ํด ๋ฐฐ์ธ ์ ์๋ ๊ฐ์ฅ ์ข์ ์ฅ์๋ผ๊ณ ๋งํ์ง๋ ์๊ฒ ์ต๋๋ค. ๐ ์ฐ๋ฆฌ๋ TF์ ํธํ๋๋๋ก ๋ ธ๋ ฅํ์ง ์์ผ๋ฉฐ, CrossEntropyLoss๋ ํด๋์ค ์ธ๋ฑ์ค ๋ฒกํฐ๋ฅผ ํ์ฉํฉ๋๋ค(์ด๋ ๊ฒ ํ๋ฉด ์ฌ์ฉํ์ ๋๋ณด๋ค ํจ์ฌ ๋น ๋ฅด๊ฒ ์คํํ ์ ์์ต๋๋ค. 1-ํซ ๋ฒกํฐ). ์ ๋ง ํ์ํ ๊ฒฝ์ฐ ๋ ํํ ์ฌ์ด๋ฅผ ๋ณํํ๋ ๊ฒ์ด ๊ฐ๋จํด์ผ ํฉ๋๋ค.
๋ฒ๊ทธ ๋ณด๊ณ ์์๋ง GitHub ๋ฌธ์ ๋ฅผ ์ฌ์ฉํ๊ณ ์์ต๋๋ค. ์ง๋ฌธ์ด ์๋ ๊ฒฝ์ฐ ํฌ๋ผ ์์ ์ง๋ฌธํ์ญ์์ค.