ΠΎ torch.nn.CrossEntropyLoss,
Ρ ΠΈΠ·ΡΡΠ°Ρ pytorch ΠΈ Π±Π΅ΡΡ ΠΏΡΠΎΠ΅ΠΊΡ anpr
(https://github.com/matthewearl/deep-anpr,
http://matthewearl.github.io/2016/05/06/cnn-anpr/)
Π² ΠΊΠ°ΡΠ΅ΡΡΠ²Π΅ ΡΠΏΡΠ°ΠΆΠ½Π΅Π½ΠΈΡ ΠΏΠ΅ΡΠ΅ΡΠ°Π΄ΠΈΡΠ΅ Π΅Π³ΠΎ Π½Π° ΠΏΠ»Π°ΡΡΠΎΡΠΌΡ pytorch.
Π΅ΡΡΡ ΠΏΡΠΎΠ±Π»Π΅ΠΌΠ°, Ρ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΡ nn.CrossEntropyLoss() ΠΊΠ°ΠΊ ΡΡΠ½ΠΊΡΠΈΡ ΠΏΠΎΡΠ΅ΡΡ:
ΠΊΡΠΈΡΠ΅ΡΠΈΠΉ=nn.CrossEntropyLoss()
Π²ΡΡ
ΠΎΠ΄Π½ΡΠ΅ Π΄Π°Π½Π½ΡΠ΅ ΠΌΠΎΠ΄Π΅Π»ΠΈ:
1.00000e-02 *
-2,5552 2,7582 2,5368 ... 5,6184 1,2288 -0,0076
-0,7033 1,3167 -1,0966 ... 4,7249 1,3217 1,8367
-0,7592 1,4777 1,8095 ... 0,8733 1,2417 1,1521
-0,1040 -0,7054 -3,4862 ... 4,7703 2,9595 1,4263
[torch.FloatTensor ΡΠ°Π·ΠΌΠ΅ΡΠΎΠΌ 4x253]
ΠΈ target.data:
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
1 0 0 ... 0 0 0
[torch.DoubleTensor ΡΠ°Π·ΠΌΠ΅ΡΠΎΠΌ 4x253]
ΠΊΠΎΠ³Π΄Π° Ρ Π·Π²ΠΎΠ½Ρ:
ΠΏΠΎΡΠ΅ΡΡ = ΠΊΡΠΈΡΠ΅ΡΠΈΠΉ (Π²ΡΡ
ΠΎΠ΄, ΡΠ΅Π»ΠΈ)
ΠΏΡΠΎΠΈΠ·ΠΎΡΠ»Π° ΠΎΡΠΈΠ±ΠΊΠ°, ΠΈΠ½ΡΠΎΡΠΌΠ°ΡΠΈΡ ΡΠ°ΠΊΠ°Ρ:
TypeError: FloatClassNLLCriterion_updateOutput ΠΏΠΎΠ»ΡΡΠΈΠ» Π½Π΅Π΄ΠΎΠΏΡΡΡΠΈΠΌΡΡ ΠΊΠΎΠΌΠ±ΠΈΠ½Π°ΡΠΈΡ Π°ΡΠ³ΡΠΌΠ΅Π½ΡΠΎΠ² β ΠΏΠΎΠ»ΡΡΠΈΠ» (int, torch.FloatTensor, torch.DoubleTensor, torch.FloatTensor, bool, NoneType, torch.FloatTensor), Π½ΠΎ ΠΎΠΆΠΈΠ΄Π°Π» (int state, torch.FloatTensor input, torch.LongTensor target) , torch.FloatTensor output, bool sizeAverage, [torch.FloatTensor weights or None], torch.FloatTensor total_weight)
'ΠΎΠΆΠΈΠ΄Π°Π΅ΠΌΡΠΉ torch.LongTensor', 'ΠΏΠΎΠ»ΡΡΠΈΠ» torch.DoubleTensor', Π½ΠΎ Π΅ΡΠ»ΠΈ Ρ ΠΊΠΎΠ½Π²Π΅ΡΡΠΈΡΡΡ ΡΠ΅Π»ΠΈ Π² LongTensor:
torch.LongTensor(numpy.array(targets.data.numpy(),numpy.long))
ΠΏΠΎΡΠ΅ΡΡ Π²ΡΠ·ΠΎΠ²Π° = ΠΊΡΠΈΡΠ΅ΡΠΈΠΉ (Π²ΡΡ
ΠΎΠ΄, ΡΠ΅Π»ΠΈ), ΠΎΡΠΈΠ±ΠΊΠ°:
ΠΡΠΈΠ±ΠΊΠ° Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ: ΠΌΠ½ΠΎΠ³ΠΎΡΠ΅Π»Π΅Π²ΠΎΠΉ ΠΎΠ±ΡΠ΅ΠΊΡ Π½Π΅ ΠΏΠΎΠ΄Π΄Π΅ΡΠΆΠΈΠ²Π°Π΅ΡΡΡ Π² /data/users/soumith/miniconda2/conda-bld/pytorch-0.1.10_1488752595704/work/torch/lib/THNN/generic/ClassNLLCriterion.c:20
ΠΌΠΎΠ΅ ΠΏΠΎΡΠ»Π΅Π΄Π½Π΅Π΅ ΡΠΏΡΠ°ΠΆΠ½Π΅Π½ΠΈΠ΅ - mnist, ΠΏΡΠΈΠΌΠ΅Ρ ΠΈΠ· pytorch, Ρ Π½Π΅ΠΌΠ½ΠΎΠ³ΠΎ ΠΌΠΎΠ΄ΠΈΡΠΈΡΠΈΡΠΎΠ²Π°Π», ΡΠ°Π·ΠΌΠ΅Ρ ΠΏΠ°ΡΡΠΈΠΈ ΡΠ°Π²Π΅Π½ 4, ΡΡΠ½ΠΊΡΠΈΡ ΠΏΠΎΡΠ΅ΡΡ:
ΠΏΠΎΡΠ΅ΡΡ = F.nll_loss (Π²ΡΡ
ΠΎΠ΄Π½ΡΠ΅ Π΄Π°Π½Π½ΡΠ΅, ΠΌΠ΅ΡΠΊΠΈ)
Π²ΡΡ
ΠΎΠ΄Ρ.Π΄Π°Π½Π½ΡΠ΅:
-2,3220 -2,1229 -2,3395 -2,3391 -2,5270 -2,3269 -2,1055 -2,2321 -2,4943 -2,2996
-2,3653 -2,2034 -2,4437 -2,2708 -2,5114 -2,3286 -2,1921 -2,1771 -2,3343 -2,2533
-2,2809 -2,2119 -2,3872 -2,2190 -2,4610 -2,2946 -2,2053 -2,3192 -2,3674 -2,3100
-2,3715 -2,1455 -2,4199 -2,4177 -2,4565 -2,2812 -2,2467 -2,1144 -2,3321 -2,3009
[torch.FloatTensor ΡΠ°Π·ΠΌΠ΅ΡΠΎΠΌ 4x10]
ΡΡΠΈΠΊΠ΅ΡΠΊΠΈ.Π΄Π°Π½Π½ΡΠ΅:
8
6
0
1
[torch.LongTensor ΡΠ°Π·ΠΌΠ΅ΡΠ° 4]
ΠΌΠ΅ΡΠΊΠΈ Π΄Π»Ρ Π²Ρ ΠΎΠ΄Π½ΠΎΠ³ΠΎ ΠΈΠ·ΠΎΠ±ΡΠ°ΠΆΠ΅Π½ΠΈΡ Π΄ΠΎΠ»ΠΆΠ½Ρ Π±ΡΡΡ ΠΎΠ΄Π½ΠΈΠΌ ΡΠ»Π΅ΠΌΠ΅Π½ΡΠΎΠΌ, Π² Π²Π΅ΡΡ Π½Π΅ΠΌ ΠΏΡΠΈΠΌΠ΅ΡΠ΅ 253 ΡΠΈΡΠ»Π°, Π° Π² Β«mnistΒ» Π΅ΡΡΡ ΡΠΎΠ»ΡΠΊΠΎ ΠΎΠ΄Π½ΠΎ ΡΠΈΡΠ»ΠΎ, ΡΠΎΡΠΌΠ° Π²ΡΡ ΠΎΠ΄ΠΎΠ² ΠΎΡΠ»ΠΈΡΠ°Π΅ΡΡΡ ΠΎΡ ΠΌΠ΅ΡΠΎΠΊ.
Ρ ΠΏΡΠΎΡΠΌΠ°ΡΡΠΈΠ²Π°Ρ ΡΡΠΊΠΎΠ²ΠΎΠ΄ΡΡΠ²ΠΎ ΠΏΠΎ ΡΠ΅Π½Π·ΠΎΡΠ½ΠΎΠΌΡ ΠΏΠΎΡΠΎΠΊΡ, tf.nn.softmax_cross_entropy_with_logits,
Β«ΠΠΎΠ³ΠΈΡΡ ΠΈ ΠΌΠ΅ΡΠΊΠΈ Π΄ΠΎΠ»ΠΆΠ½Ρ ΠΈΠΌΠ΅ΡΡ ΠΎΠ΄ΠΈΠ½Π°ΠΊΠΎΠ²ΡΡ ΡΠΎΡΠΌΡ [batch_size, num_classes] ΠΈ ΠΎΠ΄ΠΈΠ½ ΠΈ ΡΠΎΡ ΠΆΠ΅ ΡΠΈΠΏ dtype (Π»ΠΈΠ±ΠΎ float32, Π»ΠΈΠ±ΠΎ float64)Β».
ΠΡΠ°ΠΊ, ΠΌΠΎΠ³Ρ Π»ΠΈ Ρ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΡ pytorch Π² ΡΡΠΎΠΌ ΡΠ»ΡΡΠ°Π΅ ΠΈΠ»ΠΈ ΠΊΠ°ΠΊ ΠΌΠ½Π΅ ΡΡΠΎ ΡΠ΄Π΅Π»Π°ΡΡ?
Π±ΠΎΠ»ΡΡΠΎΠ΅ ΡΠΏΠ°ΡΠΈΠ±ΠΎ
http://pytorch.org/docs/nn.html#crossentropyloss
Π€ΠΎΡΠΌΠ° CrossEntropyLoss:
ΠΡ
ΠΎΠ΄: (N,C), Π³Π΄Π΅ C = ΠΊΠΎΠ»ΠΈΡΠ΅ΡΡΠ²ΠΎ ΠΊΠ»Π°ΡΡΠΎΠ²
Π¦Π΅Π»Ρ: (N), Π³Π΄Π΅ ΠΊΠ°ΠΆΠ΄ΠΎΠ΅ Π·Π½Π°ΡΠ΅Π½ΠΈΠ΅ ΡΠ°Π²Π½ΠΎ 0 <= target[i] <= C-1
Ρ ΠΌΠ΅Π½Ρ Π΅ΡΡΡ Π΄ΡΡΠ³ΠΎΠΉ Π²ΡΠ±ΠΎΡ?
Π― Π±Ρ Π½Π΅ ΡΠΊΠ°Π·Π°Π», ΡΡΠΎ Π΄ΠΎΠΊΡΠΌΠ΅Π½ΡΡ TF β Π»ΡΡΡΠ΅Π΅ ΠΌΠ΅ΡΡΠΎ Π΄Π»Ρ ΠΈΠ·ΡΡΠ΅Π½ΠΈΡ API PyTorch π ΠΡ Π½Π΅ ΠΏΡΡΠ°Π΅ΠΌΡΡ Π±ΡΡΡ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΡΠΌΠΈ Ρ TF, ΠΈ Π½Π°Ρ CrossEntropyLoss ΠΏΡΠΈΠ½ΠΈΠΌΠ°Π΅Ρ Π²Π΅ΠΊΡΠΎΡ ΠΈΠ½Π΄Π΅ΠΊΡΠΎΠ² ΠΊΠ»Π°ΡΡΠΎΠ² (ΡΡΠΎ ΠΏΠΎΠ·Π²ΠΎΠ»ΡΠ΅Ρ Π΅ΠΌΡ ΡΠ°Π±ΠΎΡΠ°ΡΡ Π½Π°ΠΌΠ½ΠΎΠ³ΠΎ Π±ΡΡΡΡΠ΅Π΅, ΡΠ΅ΠΌ Π΅ΡΠ»ΠΈ Π±Ρ ΠΎΠ½ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π» 1-Π³ΠΎΡΡΡΠΈΠ΅ Π²Π΅ΠΊΡΠΎΡΡ). ΠΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ΅ΠΆΠ΄Ρ ΠΎΠ±ΠΎΠΈΠΌΠΈ ΠΏΡΠ΅Π΄ΡΡΠ°Π²Π»Π΅Π½ΠΈΡΠΌΠΈ Π΄ΠΎΠ»ΠΆΠ½ΠΎ Π±ΡΡΡ ΠΏΡΠΎΡΡΡΠΌ, Π΅ΡΠ»ΠΈ Π²Π°ΠΌ ΡΡΠΎ Π΄Π΅ΠΉΡΡΠ²ΠΈΡΠ΅Π»ΡΠ½ΠΎ Π½ΡΠΆΠ½ΠΎ.
ΠΠ±ΡΠ°ΡΠΈΡΠ΅ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅, ΡΡΠΎ ΠΌΡ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΠΌ Π²ΠΎΠΏΡΠΎΡΡ GitHub ΡΠΎΠ»ΡΠΊΠΎ Π΄Π»Ρ ΠΎΡΡΠ΅ΡΠΎΠ² ΠΎΠ± ΠΎΡΠΈΠ±ΠΊΠ°Ρ . ΠΡΠ»ΠΈ Ρ Π²Π°Ρ Π΅ΡΡΡ Π²ΠΎΠΏΡΠΎΡΡ, Π·Π°Π΄Π°Π²Π°ΠΉΡΠ΅ ΠΈΡ Π½Π° Π½Π°ΡΠΈΡ ΡΠΎΡΡΠΌΠ°Ρ .
Π‘ΠΏΠ°ΡΠΈΠ±ΠΎ, ΠΏΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ°ΡΡΠΈΡΡ ΠΊΠΎΠ΄ΠΈΡΠΎΠ²Π°Π½ΠΈΡ ΠΎΠ΄Π½ΠΎΠ³ΠΎ Π³ΠΎΡΡΡΠ΅Π³ΠΎ ΠΊΠ»Π°ΡΡΠ° Π² ΡΠ΅Π»ΠΎΡΠΈΡΠ»Π΅Π½Π½ΡΠΉ Π²Π΅ΠΊΡΠΎΡ ΡΡΡΡΠ°Π½ΠΈΠ»ΠΎ ΠΏΡΠΎΠ±Π»Π΅ΠΌΡ CrossEntropyLoss Π΄Π»Ρ ΠΌΠ΅Π½Ρ!
Π‘Π°ΠΌΡΠΉ ΠΏΠΎΠ»Π΅Π·Π½ΡΠΉ ΠΊΠΎΠΌΠΌΠ΅Π½ΡΠ°ΡΠΈΠΉ
Π― Π±Ρ Π½Π΅ ΡΠΊΠ°Π·Π°Π», ΡΡΠΎ Π΄ΠΎΠΊΡΠΌΠ΅Π½ΡΡ TF β Π»ΡΡΡΠ΅Π΅ ΠΌΠ΅ΡΡΠΎ Π΄Π»Ρ ΠΈΠ·ΡΡΠ΅Π½ΠΈΡ API PyTorch π ΠΡ Π½Π΅ ΠΏΡΡΠ°Π΅ΠΌΡΡ Π±ΡΡΡ ΡΠΎΠ²ΠΌΠ΅ΡΡΠΈΠΌΡΠΌΠΈ Ρ TF, ΠΈ Π½Π°Ρ CrossEntropyLoss ΠΏΡΠΈΠ½ΠΈΠΌΠ°Π΅Ρ Π²Π΅ΠΊΡΠΎΡ ΠΈΠ½Π΄Π΅ΠΊΡΠΎΠ² ΠΊΠ»Π°ΡΡΠΎΠ² (ΡΡΠΎ ΠΏΠΎΠ·Π²ΠΎΠ»ΡΠ΅Ρ Π΅ΠΌΡ ΡΠ°Π±ΠΎΡΠ°ΡΡ Π½Π°ΠΌΠ½ΠΎΠ³ΠΎ Π±ΡΡΡΡΠ΅Π΅, ΡΠ΅ΠΌ Π΅ΡΠ»ΠΈ Π±Ρ ΠΎΠ½ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°Π» 1-Π³ΠΎΡΡΡΠΈΠ΅ Π²Π΅ΠΊΡΠΎΡΡ). ΠΡΠ΅ΠΎΠ±ΡΠ°Π·ΠΎΠ²Π°Π½ΠΈΠ΅ ΠΌΠ΅ΠΆΠ΄Ρ ΠΎΠ±ΠΎΠΈΠΌΠΈ ΠΏΡΠ΅Π΄ΡΡΠ°Π²Π»Π΅Π½ΠΈΡΠΌΠΈ Π΄ΠΎΠ»ΠΆΠ½ΠΎ Π±ΡΡΡ ΠΏΡΠΎΡΡΡΠΌ, Π΅ΡΠ»ΠΈ Π²Π°ΠΌ ΡΡΠΎ Π΄Π΅ΠΉΡΡΠ²ΠΈΡΠ΅Π»ΡΠ½ΠΎ Π½ΡΠΆΠ½ΠΎ.
ΠΠ±ΡΠ°ΡΠΈΡΠ΅ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅, ΡΡΠΎ ΠΌΡ ΠΈΡΠΏΠΎΠ»ΡΠ·ΡΠ΅ΠΌ Π²ΠΎΠΏΡΠΎΡΡ GitHub ΡΠΎΠ»ΡΠΊΠΎ Π΄Π»Ρ ΠΎΡΡΠ΅ΡΠΎΠ² ΠΎΠ± ΠΎΡΠΈΠ±ΠΊΠ°Ρ . ΠΡΠ»ΠΈ Ρ Π²Π°Ρ Π΅ΡΡΡ Π²ΠΎΠΏΡΠΎΡΡ, Π·Π°Π΄Π°Π²Π°ΠΉΡΠ΅ ΠΈΡ Π½Π° Π½Π°ΡΠΈΡ ΡΠΎΡΡΠΌΠ°Ρ .