Pytorch: ΠΎ torch.nn.CrossEntropyLoss

Π‘ΠΎΠ·Π΄Π°Π½Π½Ρ‹ΠΉ Π½Π° 14 Π°ΠΏΡ€. 2017  Β·  3ΠšΠΎΠΌΠΌΠ΅Π½Ρ‚Π°Ρ€ΠΈΠΈ  Β·  Π˜ΡΡ‚ΠΎΡ‡Π½ΠΈΠΊ: pytorch/pytorch

ΠΎ torch.nn.CrossEntropyLoss,
я ΠΈΠ·ΡƒΡ‡Π°ΡŽ pytorch ΠΈ Π±Π΅Ρ€Ρƒ ΠΏΡ€ΠΎΠ΅ΠΊΡ‚ anpr
(https://github.com/matthewearl/deep-anpr,
http://matthewearl.github.io/2016/05/06/cnn-anpr/)
Π² качСствС упраТнСния пСрСсадитС Π΅Π³ΠΎ Π½Π° ΠΏΠ»Π°Ρ‚Ρ„ΠΎΡ€ΠΌΡƒ pytorch.

Π΅ΡΡ‚ΡŒ ΠΏΡ€ΠΎΠ±Π»Π΅ΠΌΠ°, я ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΡŽ nn.CrossEntropyLoss() ΠΊΠ°ΠΊ Ρ„ΡƒΠ½ΠΊΡ†ΠΈΡŽ ΠΏΠΎΡ‚Π΅Ρ€ΡŒ:
ΠΊΡ€ΠΈΡ‚Π΅Ρ€ΠΈΠΉ=nn.CrossEntropyLoss()

Π²Ρ‹Ρ…ΠΎΠ΄Π½Ρ‹Π΅ Π΄Π°Π½Π½Ρ‹Π΅ ΠΌΠΎΠ΄Π΅Π»ΠΈ:
1.00000e-02 *
-2,5552 2,7582 2,5368 ... 5,6184 1,2288 -0,0076
-0,7033 1,3167 -1,0966 ... 4,7249 1,3217 1,8367
-0,7592 1,4777 1,8095 ... 0,8733 1,2417 1,1521
-0,1040 -0,7054 -3,4862 ... 4,7703 2,9595 1,4263
[torch.FloatTensor Ρ€Π°Π·ΠΌΠ΅Ρ€ΠΎΠΌ 4x253]

ΠΈ target.data:
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
[torch.DoubleTensor Ρ€Π°Π·ΠΌΠ΅Ρ€ΠΎΠΌ 4x253]

когда я звоню:
потСря = ΠΊΡ€ΠΈΡ‚Π΅Ρ€ΠΈΠΉ (Π²Ρ‹Ρ…ΠΎΠ΄, Ρ†Π΅Π»ΠΈ)
ΠΏΡ€ΠΎΠΈΠ·ΠΎΡˆΠ»Π° ошибка, информация такая:
TypeError: FloatClassNLLCriterion_updateOutput ΠΏΠΎΠ»ΡƒΡ‡ΠΈΠ» Π½Π΅Π΄ΠΎΠΏΡƒΡΡ‚ΠΈΠΌΡƒΡŽ ΠΊΠΎΠΌΠ±ΠΈΠ½Π°Ρ†ΠΈΡŽ Π°Ρ€Π³ΡƒΠΌΠ΅Π½Ρ‚ΠΎΠ² β€” ΠΏΠΎΠ»ΡƒΡ‡ΠΈΠ» (int, torch.FloatTensor, torch.DoubleTensor, torch.FloatTensor, bool, NoneType, torch.FloatTensor), Π½ΠΎ ΠΎΠΆΠΈΠ΄Π°Π» (int state, torch.FloatTensor input, torch.LongTensor target) , torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight)

'ΠΎΠΆΠΈΠ΄Π°Π΅ΠΌΡ‹ΠΉ torch.LongTensor', 'ΠΏΠΎΠ»ΡƒΡ‡ΠΈΠ» torch.DoubleTensor', Π½ΠΎ Ссли я ΠΊΠΎΠ½Π²Π΅Ρ€Ρ‚ΠΈΡ€ΡƒΡŽ Ρ†Π΅Π»ΠΈ Π² LongTensor:
torch.LongTensor(numpy.array(targets.data.numpy(),numpy.long))
потСря Π²Ρ‹Π·ΠΎΠ²Π° = ΠΊΡ€ΠΈΡ‚Π΅Ρ€ΠΈΠΉ (Π²Ρ‹Ρ…ΠΎΠ΄, Ρ†Π΅Π»ΠΈ), ошибка:
Ошибка выполнСния: ΠΌΠ½ΠΎΠ³ΠΎΡ†Π΅Π»Π΅Π²ΠΎΠΉ ΠΎΠ±ΡŠΠ΅ΠΊΡ‚ Π½Π΅ поддСрТиваСтся Π² /data/users/soumith/miniconda2/conda-bld/pytorch-0.1.10_1488752595704/work/torch/lib/THNN/generic/ClassNLLCriterion.c:20

ΠΌΠΎΠ΅ послСднСС ΡƒΠΏΡ€Π°ΠΆΠ½Π΅Π½ΠΈΠ΅ - mnist, ΠΏΡ€ΠΈΠΌΠ΅Ρ€ ΠΈΠ· pytorch, я Π½Π΅ΠΌΠ½ΠΎΠ³ΠΎ ΠΌΠΎΠ΄ΠΈΡ„ΠΈΡ†ΠΈΡ€ΠΎΠ²Π°Π», Ρ€Π°Π·ΠΌΠ΅Ρ€ ΠΏΠ°Ρ€Ρ‚ΠΈΠΈ Ρ€Π°Π²Π΅Π½ 4, функция ΠΏΠΎΡ‚Π΅Ρ€ΡŒ:
потСря = F.nll_loss (Π²Ρ‹Ρ…ΠΎΠ΄Π½Ρ‹Π΅ Π΄Π°Π½Π½Ρ‹Π΅, ΠΌΠ΅Ρ‚ΠΊΠΈ)
Π²Ρ‹Ρ…ΠΎΠ΄Ρ‹.Π΄Π°Π½Π½Ρ‹Π΅:
-2,3220 -2,1229 -2,3395 -2,3391 -2,5270 -2,3269 -2,1055 -2,2321 -2,4943 -2,2996
-2,3653 -2,2034 -2,4437 -2,2708 -2,5114 -2,3286 -2,1921 -2,1771 -2,3343 -2,2533
-2,2809 -2,2119 -2,3872 -2,2190 -2,4610 -2,2946 -2,2053 -2,3192 -2,3674 -2,3100
-2,3715 -2,1455 -2,4199 -2,4177 -2,4565 -2,2812 -2,2467 -2,1144 -2,3321 -2,3009
[torch.FloatTensor Ρ€Π°Π·ΠΌΠ΅Ρ€ΠΎΠΌ 4x10]

этикСтки.Π΄Π°Π½Π½Ρ‹Π΅:
8
6
0
1
[torch.LongTensor Ρ€Π°Π·ΠΌΠ΅Ρ€Π° 4]

ΠΌΠ΅Ρ‚ΠΊΠΈ для Π²Ρ…ΠΎΠ΄Π½ΠΎΠ³ΠΎ изобраТСния Π΄ΠΎΠ»ΠΆΠ½Ρ‹ Π±Ρ‹Ρ‚ΡŒ ΠΎΠ΄Π½ΠΈΠΌ элСмСнтом, Π² Π²Π΅Ρ€Ρ…Π½Π΅ΠΌ ΠΏΡ€ΠΈΠΌΠ΅Ρ€Π΅ 253 числа, Π° Π² Β«mnistΒ» Π΅ΡΡ‚ΡŒ Ρ‚ΠΎΠ»ΡŒΠΊΠΎ ΠΎΠ΄Π½ΠΎ число, Ρ„ΠΎΡ€ΠΌΠ° Π²Ρ‹Ρ…ΠΎΠ΄ΠΎΠ² отличаСтся ΠΎΡ‚ ΠΌΠ΅Ρ‚ΠΎΠΊ.

я ΠΏΡ€ΠΎΡΠΌΠ°Ρ‚Ρ€ΠΈΠ²Π°ΡŽ руководство ΠΏΠΎ Ρ‚Π΅Π½Π·ΠΎΡ€Π½ΠΎΠΌΡƒ ΠΏΠΎΡ‚ΠΎΠΊΡƒ, tf.nn.softmax_cross_entropy_with_logits,
Β«Π›ΠΎΠ³ΠΈΡ‚Ρ‹ ΠΈ ΠΌΠ΅Ρ‚ΠΊΠΈ Π΄ΠΎΠ»ΠΆΠ½Ρ‹ ΠΈΠΌΠ΅Ρ‚ΡŒ ΠΎΠ΄ΠΈΠ½Π°ΠΊΠΎΠ²ΡƒΡŽ Ρ„ΠΎΡ€ΠΌΡƒ [batch_size, num_classes] ΠΈ ΠΎΠ΄ΠΈΠ½ ΠΈ Ρ‚ΠΎΡ‚ ΠΆΠ΅ Ρ‚ΠΈΠΏ dtype (Π»ΠΈΠ±ΠΎ float32, Π»ΠΈΠ±ΠΎ float64)Β».

Π˜Ρ‚Π°ΠΊ, ΠΌΠΎΠ³Ρƒ Π»ΠΈ я ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΠΎΠ²Π°Ρ‚ΡŒ pytorch Π² этом случаС ΠΈΠ»ΠΈ ΠΊΠ°ΠΊ ΠΌΠ½Π΅ это ΡΠ΄Π΅Π»Π°Ρ‚ΡŒ?
большоС спасибо

Π‘Π°ΠΌΡ‹ΠΉ ΠΏΠΎΠ»Π΅Π·Π½Ρ‹ΠΉ ΠΊΠΎΠΌΠΌΠ΅Π½Ρ‚Π°Ρ€ΠΈΠΉ

Π― Π±Ρ‹ Π½Π΅ сказал, Ρ‡Ρ‚ΠΎ Π΄ΠΎΠΊΡƒΠΌΠ΅Π½Ρ‚Ρ‹ TF β€” Π»ΡƒΡ‡ΡˆΠ΅Π΅ мСсто для изучСния API PyTorch πŸ™‚ ΠœΡ‹ Π½Π΅ пытаСмся Π±Ρ‹Ρ‚ΡŒ совмСстимыми с TF, ΠΈ наш CrossEntropyLoss ΠΏΡ€ΠΈΠ½ΠΈΠΌΠ°Π΅Ρ‚ Π²Π΅ΠΊΡ‚ΠΎΡ€ индСксов классов (это позволяСт Π΅ΠΌΡƒ Ρ€Π°Π±ΠΎΡ‚Π°Ρ‚ΡŒ Π½Π°ΠΌΠ½ΠΎΠ³ΠΎ быстрСС, Ρ‡Π΅ΠΌ Ссли Π±Ρ‹ ΠΎΠ½ использовал 1-горячиС Π²Π΅ΠΊΡ‚ΠΎΡ€Ρ‹). ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ΅ΠΆΠ΄Ρƒ ΠΎΠ±ΠΎΠΈΠΌΠΈ прСдставлСниями Π΄ΠΎΠ»ΠΆΠ½ΠΎ Π±Ρ‹Ρ‚ΡŒ простым, Ссли Π²Π°ΠΌ это Π΄Π΅ΠΉΡΡ‚Π²ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ Π½ΡƒΠΆΠ½ΠΎ.

ΠžΠ±Ρ€Π°Ρ‚ΠΈΡ‚Π΅ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅, Ρ‡Ρ‚ΠΎ ΠΌΡ‹ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅ΠΌ вопросы GitHub Ρ‚ΠΎΠ»ΡŒΠΊΠΎ для ΠΎΡ‚Ρ‡Π΅Ρ‚ΠΎΠ² ΠΎΠ± ΠΎΡˆΠΈΠ±ΠΊΠ°Ρ…. Если Ρƒ вас Π΅ΡΡ‚ΡŒ вопросы, Π·Π°Π΄Π°Π²Π°ΠΉΡ‚Π΅ ΠΈΡ… Π½Π° Π½Π°ΡˆΠΈΡ… Ρ„ΠΎΡ€ΡƒΠΌΠ°Ρ… .

ВсС 3 ΠšΠΎΠΌΠΌΠ΅Π½Ρ‚Π°Ρ€ΠΈΠΉ

http://pytorch.org/docs/nn.html#crossentropyloss
Π€ΠΎΡ€ΠΌΠ° CrossEntropyLoss:
Π’Ρ…ΠΎΠ΄: (N,C), Π³Π΄Π΅ C = количСство классов
ЦСль: (N), Π³Π΄Π΅ ΠΊΠ°ΠΆΠ΄ΠΎΠ΅ Π·Π½Π°Ρ‡Π΅Π½ΠΈΠ΅ Ρ€Π°Π²Π½ΠΎ 0 <= target[i] <= C-1

Ρƒ мСня Π΅ΡΡ‚ΡŒ Π΄Ρ€ΡƒΠ³ΠΎΠΉ Π²Ρ‹Π±ΠΎΡ€?

Π― Π±Ρ‹ Π½Π΅ сказал, Ρ‡Ρ‚ΠΎ Π΄ΠΎΠΊΡƒΠΌΠ΅Π½Ρ‚Ρ‹ TF β€” Π»ΡƒΡ‡ΡˆΠ΅Π΅ мСсто для изучСния API PyTorch πŸ™‚ ΠœΡ‹ Π½Π΅ пытаСмся Π±Ρ‹Ρ‚ΡŒ совмСстимыми с TF, ΠΈ наш CrossEntropyLoss ΠΏΡ€ΠΈΠ½ΠΈΠΌΠ°Π΅Ρ‚ Π²Π΅ΠΊΡ‚ΠΎΡ€ индСксов классов (это позволяСт Π΅ΠΌΡƒ Ρ€Π°Π±ΠΎΡ‚Π°Ρ‚ΡŒ Π½Π°ΠΌΠ½ΠΎΠ³ΠΎ быстрСС, Ρ‡Π΅ΠΌ Ссли Π±Ρ‹ ΠΎΠ½ использовал 1-горячиС Π²Π΅ΠΊΡ‚ΠΎΡ€Ρ‹). ΠŸΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ΅ΠΆΠ΄Ρƒ ΠΎΠ±ΠΎΠΈΠΌΠΈ прСдставлСниями Π΄ΠΎΠ»ΠΆΠ½ΠΎ Π±Ρ‹Ρ‚ΡŒ простым, Ссли Π²Π°ΠΌ это Π΄Π΅ΠΉΡΡ‚Π²ΠΈΡ‚Π΅Π»ΡŒΠ½ΠΎ Π½ΡƒΠΆΠ½ΠΎ.

ΠžΠ±Ρ€Π°Ρ‚ΠΈΡ‚Π΅ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅, Ρ‡Ρ‚ΠΎ ΠΌΡ‹ ΠΈΡΠΏΠΎΠ»ΡŒΠ·ΡƒΠ΅ΠΌ вопросы GitHub Ρ‚ΠΎΠ»ΡŒΠΊΠΎ для ΠΎΡ‚Ρ‡Π΅Ρ‚ΠΎΠ² ΠΎΠ± ΠΎΡˆΠΈΠ±ΠΊΠ°Ρ…. Если Ρƒ вас Π΅ΡΡ‚ΡŒ вопросы, Π·Π°Π΄Π°Π²Π°ΠΉΡ‚Π΅ ΠΈΡ… Π½Π° Π½Π°ΡˆΠΈΡ… Ρ„ΠΎΡ€ΡƒΠΌΠ°Ρ… .

Бпасибо, ΠΏΡ€Π΅ΠΎΠ±Ρ€Π°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ°Ρ‚Ρ€ΠΈΡ†Ρ‹ кодирования ΠΎΠ΄Π½ΠΎΠ³ΠΎ горячСго класса Π² цСлочислСнный Π²Π΅ΠΊΡ‚ΠΎΡ€ устранило ΠΏΡ€ΠΎΠ±Π»Π΅ΠΌΡƒ CrossEntropyLoss для мСня!

Π‘Ρ‹Π»Π° Π»ΠΈ эта страница ΠΏΠΎΠ»Π΅Π·Π½ΠΎΠΉ?
0 / 5 - 0 Ρ€Π΅ΠΉΡ‚ΠΈΠ½Π³ΠΈ