์๋
, ์๋ค์. target ์ torch.LongTensor
์ ํ์ reference ์ ์ผ๋ถ ๋ฉ์๋์ ๊ฐ์ด ๊ตฌํ์ ๋ฐฉํดํฉ๋๋ค. ๋ฐ๋ผ์ Arg: label_smoothing
๋ํด torch.nn.CrossEntropyLoss()
๋ฅผ ์ถ๊ฐํ๊ฑฐ๋ target
๋ฅผ one-hot vector
๋ก ๋ณํํ์ฌ ์์
ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ฃผ๋ ๋ฌธ์๋ฅผ ์ถ๊ฐํ ์ ์์ต๋๊น? torch.nn.CrossEntropyLoss()
ํจ๊ป ๋๋ ๋ค๋ฅธ ๊ฐ๋จํ ๋ฐฉ๋ฒ? ๊ฐ์ฌ ํด์.
cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @mruberry
@KaiyuYue
label_smoothing์ ๊ฒฝ์ฐ NJUNMT-pytorch ๊ตฌํ์
์์
์์ NMTCritierion
https://discuss.pytorch.org/t/cross-entropy-with-one-hot-targets/13580/5๋ฅผ ์ฐธ์กฐ cross_entropy()
ํจ์๋ ๋คํธ์ํฌ ์ถ๋ ฅ๊ณผ ๋์ผํ ์ฐจ์์ ๊ฐ๋ ํํํ๋ ๋ ์ด๋ธ๊ณผ ํจ๊ป ์๋ํด์ผ ํฉ๋๋ค.
๋ ์ด๋ธ ์ค๋ฌด๋ฉ์ ๋ค์ํ ๋ฐฉ๋ฒ์ผ๋ก ์ํํ ์ ์๊ณ ์ค๋ฌด๋ฉ ์์ฒด๋ ์ฌ์ฉ์๊ฐ ์๋์ผ๋ก ์ฝ๊ฒ ์ํํ ์ ์๊ธฐ ๋๋ฌธ์ CrossEntropyLoss()
๊ฐ label_smoothing
์ต์
์ ์ง์ ์ง์ํด์ผ ํ๋ค๊ณ ์๊ฐํ์ง ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ค์นผ๋ผ ๊ฐ์ผ๋ก ํํํ ์ ์๋ ๋์์ ์ฒ๋ฆฌํ๋ ๋ฐฉ๋ฒ์ด๋ CrossEntropyLoss
(k-hot/smoothed) ๋์ ์ ๋ฌ์ ๋ํ ์ง์์ ์ถ๊ฐํ๋ ๋ฐฉ๋ฒ์ ๋ํด ๋ฌธ์์์ ์ต์ํ ์ธ๊ธํด์ผ ํ๋ค๋ ๋ฐ ๋์ํฉ๋๋ค.
์ด์ฉ๋ฉด NonSparseCrossEntropy
์ ๊ฐ์ sth๊ฐ ํ์ํ ๊น์? (๊ธ์.. ์ด๋ฆ์ ์ง๊ธฐ ์ด๋ ต๋ค)
์ฌ๊ธฐ ๋ด ๋๊ตฌ๊ฐ ์์ต๋๋ค
class LabelSmoothingLoss(nn.Module):
def __init__(self, classes, smoothing=0.0, dim=-1):
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.cls = classes
self.dim = dim
def forward(self, pred, target):
pred = pred.log_softmax(dim=self.dim)
with torch.no_grad():
# true_dist = pred.data.clone()
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (self.cls - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
@mdraw์ ๋์ํฉ๋๋ค.
์ข์ ์ ํ์ ๋ ๋จ๊ณ๋ก ์ํํ๋ ๊ฒ์
๋๋ค.
def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
"""
if smoothing == 0, it's one-hot method
if 0 < smoothing < 1, it's smooth method
"""
assert 0 <= smoothing < 1
confidence = 1.0 - smoothing
label_shape = torch.Size((true_labels.size(0), classes))
with torch.no_grad():
true_dist = torch.empty(size=label_shape, device=true_labels.device)
true_dist.fill_(smoothing / (classes - 1))
true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
return true_dist
CrossEntropyLoss
k-hot/smoothed ๋์์ ์ง์ํ๋๋ก ํฉ๋๋ค.๊ทธ๋ฌ๋ฉด ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉํ ์ ์์ต๋๋ค.
Loss = CrossEntropyLoss(NonSparse=True, ...)
. . .
data = ...
labels = ...
outputs = model(data)
smooth_label = smooth_one_hot(labels, ...)
loss = (outputs, smooth_label)
...
๊ทธ๋ฐ๋ฐ ImageNet์์ ๋ด ๊ตฌํ์ ํ ์คํธํ๋๋ฐ ์ข์ ๋ณด์ ๋๋ค.
|๋ชจ๋ธ | ์๋| dtype |๋ฐฐ์น ํฌ๊ธฐ*|gpus | lr | ํธ๋ฆญ|top1/top5 |๊ฐ์ |
|:----:|:-----:|:-----:|:---------:|:----:|:---:|: ------:|:---------:|:------:|
|resnet50|120 |FP16 |128 | 8 |0.4 | - |77.35/- |๊ธฐ์ค|
|resnet50|120 |FP16 |128 | 8 |0.4 |๋ ์ด๋ธ ์ค๋ฌด๋ฉ|77.78/93.80| +0.43 |
@zhangguanheng66 ์ด ์ด๊ฒ์ด ๊ทธ๊ฐ ๋ฏธ๋์ ๋ณผ ์ ์์ ๊ฒ์ด๋ผ๊ณ ๋งํ ๊ฒ ๊ฐ์ต๋๋ค.
๊ทธ๋ฅ torch.nn.KLDivLoss๋ฅผ ์ฌ์ฉํ์ธ์. ๊ทธ๊ฒ์ ๋์ผํฉ๋๋ค.
์ ๋ฐ์ดํธ: ๋์ผํ์ง ์์ต๋๋ค.
๋๋ ์ด๊ฒ์ด ์๋ก์ด Snorkel lib๊ฐ ๊ตฌํํ ๊ฒ๊ณผ ์ ์ฌํ๋ค๊ณ ์๊ฐํฉ๋๋ค.
https://snorkel.readthedocs.io/en/master/packages/_autosummary/classification/snorkel.classification.cross_entropy_with_probs.html
์ฌ๋๋ค์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๋ฐฉ๋ฒ์ ๋ํ ๋ช ๊ฐ์ง ์ถ๊ฐ ์ ๋ณด
Nvidia๊ฐ ๋์์ด ๋ ์ ์๋ ๋ฐฉ๋ฒ์ https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5 ๋ฅผ ์ฐธ์กฐ
@suanrong ๊ฐ์ฌํฉ๋๋ค.
====
๊ทธ๋ฆฌ๊ณ ์๋ง๋ ์ด๊ฒ์ ์ด ๋ฌธ์ ๋ฅผ ์ฝ๋ ๋ค๋ฅธ ์ฌ๋๋ค์๊ฒ ๋์์ด ๋ ๊ฒ์
๋๋ค.
0/1์ด ์๋ ๋ ์ด๋ธ์ ๋ํ ๊ต์ฐจ ์ํธ๋กํผ๋ ๋์นญ์ด ์๋๋ฏ๋ก ์ฑ๋ฅ ์ ํ์ ๋ํ ์ค๋ช ์ด ๋ ์ ์์ต๋๋ค.
https://discuss.pytorch.org/t/cross-entropy-for-soft-label/16093/2
์ ์๋ ๊ตฌํ:
class LabelSmoothLoss(nn.Module):
def __init__(self, smoothing=0.0):
super(LabelSmoothLoss, self).__init__()
self.smoothing = smoothing
def forward(self, input, target):
log_prob = F.log_softmax(input, dim=-1)
weight = input.new_ones(input.size()) * \
self.smoothing / (input.size(-1) - 1.)
weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
loss = (-weight * log_prob).sum(dim=-1).mean()
return loss
๋๋ ๊ทธ๊ฒ์ ํ์ธํ๋ค :
(1) smoothing=0.0์ผ ๋ 1e-5
์ ๋ฐ๋ ๋ด์์ nn.CrossEntropyLoss
์ ๋์ผํ๊ฒ ์ถ๋ ฅ๋ฉ๋๋ค.
(2) ์ค๋ฌด๋ฉ>0.0์ผ ๋ ์๋ก ๋ค๋ฅธ ํด๋์ค weight.sum(dim=-1)
๋ํ ๊ฐ์ค์น์ ํฉ์ ํญ์ 1์
๋๋ค.
์ฌ๊ธฐ ๊ตฌํ์๋ ํด๋์ค ๊ฐ์ค์น ๊ธฐ๋ฅ์ด ์์ต๋๋ค.
((
๊ทธ๋ฅ torch.nn.KLDivLoss๋ฅผ ์ฌ์ฉํ์ธ์. ๊ทธ๊ฒ์ ๋์ผํฉ๋๋ค.
๋ ์์ธํ ์ค๋ช ํด ์ฃผ์๊ฒ ์ต๋๊น
๊ทธ๋ฅ torch.nn.KLDivLoss๋ฅผ ์ฌ์ฉํ์ธ์. ๊ทธ๊ฒ์ ๋์ผํฉ๋๋ค.
๋ ์์ธํ ์ค๋ช ํด ์ฃผ์๊ฒ ์ต๋๊น
์ด๋ฏธ ์ค๋ฌด๋ฉ๋ ๋ ์ด๋ธ์ด ์๋ค๊ณ ๊ฐ์ ํ๋ฉด ๋ ๊ฐ์ ์ฐจ์ด๊ฐ ๋ ์ด๋ธ์ ์ํธ๋กํผ์ด๊ณ ์์์ด๊ธฐ ๋๋ฌธ์ torch.nn.KLDivLoss๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค.
@PistonY ์ ์ด๋ ๊ฒ ๊ฐ๋จํ๊ฒ ์ฌ์ฉํ์ง
with torch.no_grad():
confidence = 1.0 - smoothing_factor
true_dist = torch.mul(labels, confidence)
true_dist = torch.add(true_dist, smoothing_factor / (classNum - 1))
print(true_dist)
return true_dist
์ฌ๊ธฐ ๊ตฌํ์๋ ํด๋์ค ๊ฐ์ค์น ๊ธฐ๋ฅ์ด ์์ต๋๋ค.
ํํ ๋ ์ด๋ธ ํ ์์ ํด๋์ค ๊ฐ์ค์น๋ฅผ ๊ณฑํ ์ ์์ต๋๊น?
def smooth_one_hot(true_labels: torch.Tensor, ํด๋์ค: int, smoothing=0.0):
""
ํํํ == 0์ด๋ฉด ์-ํซ ๋ฐฉ๋ฒ์ ๋๋ค.
0 < ์ค๋ฌด๋ฉ < 1์ด๋ฉด ๋ถ๋๋ฌ์ด ๋ฐฉ๋ฒ์ ๋๋ค.""" assert 0 <= smoothing < 1 confidence = 1.0 - smoothing label_shape = torch.Size((true_labels.size(0), classes)) with torch.no_grad(): true_dist = torch.empty(size=label_shape, device=true_labels.device) true_dist.fill_(smoothing / (classes - 1)) true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence) return true_dist
```
์ด ๊ตฌํ์ ๋ฌธ์ ๋ ํด๋์ค ์์ ๋งค์ฐ ๋ฏผ๊ฐํ๋ค๋ ๊ฒ์ ๋๋ค.
n_classes๊ฐ 2์ธ ๊ฒฝ์ฐ 0.5๋ฅผ ์ด๊ณผํ๋ ์ค๋ฌด๋ฉ์ ๋ ์ด๋ธ์ ๋ค์ง ์ต๋๋ค. ์ด๋ ์ฌ์ฉ์๊ฐ ์ํ์ง ์์ ๊ฒ์ ๋๋ค. n_classes๊ฐ 3์ด๋ฉด 2/3์ ์ด๊ณผํ๋ ์ค๋ฌด๋ฉ์ด๊ณ 4๊ฐ ํด๋์ค์ ๋ํด 0.75์ ๋๋ค. ๊ทธ๋์ ์๋ง๋:
assert 0 <= smoothing < (classes-1)/classes
์ด ๋ฌธ์ ๋ฅผ ์ก์ ์ ์์ง๋ง ํํํ์ ํด๋์ค ์๋ฅผ ๊ณ ๋ คํด์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๊น?
def smooth_one_hot(true_labels: torch.Tensor, ํด๋์ค: int, smoothing=0.0):
"""
if smoothing == 0, it's one-hot method
if 0 < smoothing < 1, it's smooth method
"""
assert 0 <= smoothing < 1
confidence = 1.0 - smoothing
label_shape = torch.Size((true_labels.size(0), classes))
with torch.no_grad():
true_dist = torch.empty(size=label_shape, device=true_labels.device)
true_dist.fill_(smoothing / (classes - 1))
true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
return true_dist
```
์ด ๊ตฌํ์ ๋ฌธ์ ๋ ํด๋์ค ์์ ๋งค์ฐ ๋ฏผ๊ฐํ๋ค๋ ๊ฒ์ ๋๋ค.
n_classes๊ฐ 2์ธ ๊ฒฝ์ฐ 0.5๋ฅผ ์ด๊ณผํ๋ ์ค๋ฌด๋ฉ์ ๋ ์ด๋ธ์ ๋ค์ง ์ต๋๋ค. ์ด๋ ์ฌ์ฉ์๊ฐ ์ํ์ง ์์ ๊ฒ์ ๋๋ค. n_classes๊ฐ 3์ด๋ฉด 2/3์ ์ด๊ณผํ๋ ์ค๋ฌด๋ฉ์ด๊ณ 4๊ฐ ํด๋์ค์ ๋ํด 0.75์ ๋๋ค. ๊ทธ๋์ ์๋ง๋:
assert 0 <= smoothing < (classes-1)/classes
์ด ๋ฌธ์ ๋ฅผ ์ก์ ์ ์์ง๋ง ํํํ์ ํด๋์ค ์๋ฅผ ๊ณ ๋ คํด์ผ ํ๋ค๊ณ ์๊ฐํฉ๋๊น?
ํ๋ช ํ ์๊ฐ์ ๋๋ค.
ํ ๋ก ํด์ฃผ์ ์ ๊ฐ์ฌํฉ๋๋ค. ๋ถ๋ถ๋ช ํ๊ณ ์ค์์ฒ๋ผ ๋ณด์ด๋ ๋ช ๊ฐ์ง ์ฌํญ์ด ์์ต๋๋ค.
๋ ์ด๋ธ ํํ์ฉ์ง๋ y_k = smoothing / n_classes + (1 - smoothing) * y_{one hot}
์
๋๋ค. ๊ฐ์ค์น์ ๊ฐ์ด๋๋๋ก smoothing / n_classes
๋์ ์ด์ธ์ ์งํ์ ๋ํ, ๊ทธ๋ฆฌ๊ณ ์ธ smoothing / n_classes + (1 - smoothing)
๋์ ํด๋์ค. ๊ทธ๋ฌ๋ @PistonY ์ ๊ตฌํ์์ torch.scatter_
ํจ์๋ ๋์ ๊ฐ์ (1 - smoothing)
๋ฎ์ด์๋๋ค(๊ทธ๋ฆฌ๊ณ ์์ ์ฉ์ด๋ ์ฌ๋ผ์ง๋๋ค).
๊ฒ๋ค๊ฐ ๊ณ์ฐ(?)์ n_classes -= 1
๋ฅผ ์ฌ์ฉํ๋ ์ด์ ๋ฅผ ์ ๋ชจ๋ฅด๊ฒ ์ต๋๋ค
๋ ์ด๋ธ ํํํ ๊ต์ฐจ ์ํธ๋กํผ ์์ค์ ์์์ ์ธ๊ธํ y
๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ์ฌ ์ฝ์ต๋๋ค.
LS(x, y) = - sum_k {y[k] * log-prob(x)}
= - sum_k {y[k] * log(exp(x[k]) / (sum_j exp(x[j])))}
= - sum_k {y[k] * (x[k] - log-sum-exp(x))}
= - sum_k {y[k] * x[k]} + log-sum-exp(x)
์ฌ๊ธฐ์ ์ธ ๋ฒ์งธ์์ ๋ค ๋ฒ์งธ ์ค์ sum_k y[k] = smoothing / n_classes * n_classes + (1 - smoothing) = 1
๋ผ๋ ์ฌ์ค์ ์ฌ์ฉํฉ๋๋ค.
KL ๋ฐ์ฐ ์์ค์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
KL(x, y) = - sum_k {y[k] * x[k] - y[k] * log(y[k])
= - sum_k {y[k] * x[k]} - sum_k {y[k] * log(y[k])}
= - sum_k {y[k] * x[k]} - Const.
๋ฐ๋ผ์ ๊ฒฐ๊ตญ LS(x, y) = KL(x, y) + log-sum-exp(x) + Const.
. ์ฌ๊ธฐ์ Const.
๋ y
์ ์ํธ๋กํผ์ ํด๋นํ๋ ์์ ํญ์ด๋ฉฐ, ์ด๋ ์ค์ ๋ก ๋ค์ค ํด๋์ค ์ค์ ์์ ์ผ์ ํฉ๋๋ค. ๊ทธ๋ฌ๋ log-sum-exp ์ฉ์ด๋ ์ด๋ป์ต๋๊น?
์ํํธ ํ๊ฒ์ ํ์ฉ ํ๋ KLDiv
loss + log-sum-exp
์ ๋์ผํ๋ค๋ ๊ฒ์ ๋ณด์ฌ์ค๋๋ค. y
. ์ด ์ฉ์ด๋ฅผ ์ญ์ ํ๋ ๊ฒ์ด ํฉ๋ฆฌ์ ์ด๊ฒ ํ๋ ๋ก์ง์ ๋ํ ๊ฐ์ ์ด ์์ต๋๊น?
๋ง์ ์ค๋ช
๊ฐ์ฌํฉ๋๋ค.
๊ฑด๋ฐฐ !
@antrec ๊ฐ์ฌ
๋น์ ์ด ๋ง์ต๋๋ค. logsoftmax ํจ์๋ฅผ ๋ฌด์ํ๊ณ ์ค์๋ฅผ ํ์ต๋๋ค.
๋ ์ด๋ธ ์ค๋ฌด๋ฉ ๊ต์ฐจ ์ํธ๋กํผ ์์ค ํจ์์ ๊ตฌํ:
import torch.nn.functional as F
def linear_combination(x, y, epsilon):
return epsilon*x + (1-epsilon)*y
def reduce_loss(loss, reduction='mean'):
return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, epsilon:float=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
def forward(self, preds, target):
n = preds.size()[-1]
log_preds = F.log_softmax(preds, dim=-1)
loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
nll = F.nll_loss(log_preds, target, reduction=self.reduction)
return linear_combination(loss/n, nll, self.epsilon)
ํ๋์ ๊ธฐ๋ฐ์ผ๋ก ํ์ดํ๋ฆฌ๋ก ์ด๋
๊ฐ์ฅ ์ ์ฉํ ๋๊ธ
์ฌ๊ธฐ ๋ด ๋๊ตฌ๊ฐ ์์ต๋๋ค