[PyTorch] [feature request] Label smoothing for CrossEntropyLoss

์— ๋งŒ๋“  2018๋…„ 05์›” 10์ผ  ยท  22์ฝ”๋ฉ˜ํŠธ  ยท  ์ถœ์ฒ˜: pytorch/pytorch

์•ˆ๋…•, ์–˜๋“ค์•„. target ์˜ torch.LongTensor ์œ ํ˜•์€ reference ์˜ ์ผ๋ถ€ ๋ฉ”์†Œ๋“œ์™€ ๊ฐ™์ด ๊ตฌํ˜„์„ ๋ฐฉํ•ดํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ Arg: label_smoothing ๋Œ€ํ•ด torch.nn.CrossEntropyLoss() ๋ฅผ ์ถ”๊ฐ€ํ•˜๊ฑฐ๋‚˜ target ๋ฅผ one-hot vector ๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ์ž‘์—…ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ฃผ๋Š” ๋ฌธ์„œ๋ฅผ ์ถ”๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๊นŒ? torch.nn.CrossEntropyLoss() ํ•จ๊ป˜ ๋˜๋Š” ๋‹ค๋ฅธ ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•? ๊ฐ์‚ฌ ํ•ด์š”.

cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @mruberry

enhancement high priority loss nn triage review triaged

๊ฐ€์žฅ ์œ ์šฉํ•œ ๋Œ“๊ธ€

์—ฌ๊ธฐ ๋‚ด ๋„๊ตฌ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

๋ชจ๋“  22 ๋Œ“๊ธ€

@KaiyuYue
label_smoothing์˜ ๊ฒฝ์šฐ NJUNMT-pytorch ๊ตฌํ˜„์„

์ˆ˜์—…์—์„œ NMTCritierion

https://github.com/whr94621/NJUNMT-pytorch/blob/aff968c0da9273dc42eabbb8ac4e459f9195f6e4/src/modules/criterions.py#L131

See https://discuss.pytorch.org/t/cross-entropy-with-one-hot-targets/13580/5 — the cross_entropy() function would need to work with smoothed labels that have the same dimension as the network output.
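
For context, such a soft-target cross entropy can be written in a few lines; this is a minimal sketch (the function name is illustrative, not a PyTorch API):

import torch
import torch.nn.functional as F

def soft_target_cross_entropy(logits, soft_targets):
    # soft_targets has the same shape as logits and each row sums to 1
    log_probs = F.log_softmax(logits, dim=-1)
    return -(soft_targets * log_probs).sum(dim=-1).mean()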

๋ ˆ์ด๋ธ” ์Šค๋ฌด๋”ฉ์€ ๋‹ค์–‘ํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ๊ณ  ์Šค๋ฌด๋”ฉ ์ž์ฒด๋Š” ์‚ฌ์šฉ์ž๊ฐ€ ์ˆ˜๋™์œผ๋กœ ์‰ฝ๊ฒŒ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— CrossEntropyLoss() ๊ฐ€ label_smoothing ์˜ต์…˜์„ ์ง์ ‘ ์ง€์›ํ•ด์•ผ ํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์Šค์นผ๋ผ ๊ฐ’์œผ๋กœ ํ‘œํ˜„ํ•  ์ˆ˜ ์—†๋Š” ๋Œ€์ƒ์„ ์ฒ˜๋ฆฌํ•˜๋Š” ๋ฐฉ๋ฒ•์ด๋‚˜ CrossEntropyLoss (k-hot/smoothed) ๋Œ€์ƒ ์ „๋‹ฌ์— ๋Œ€ํ•œ ์ง€์›์„ ์ถ”๊ฐ€ํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ๋ฌธ์„œ์—์„œ ์ตœ์†Œํ•œ ์–ธ๊ธ‰ํ•ด์•ผ ํ•œ๋‹ค๋Š” ๋ฐ ๋™์˜ํ•ฉ๋‹ˆ๋‹ค.

์–ด์ฉŒ๋ฉด NonSparseCrossEntropy ์™€ ๊ฐ™์€ sth๊ฐ€ ํ•„์š”ํ• ๊นŒ์š”? (๊ธ€์Ž„.. ์ด๋ฆ„์„ ์ง“๊ธฐ ์–ด๋ ต๋‹ค)

์—ฌ๊ธฐ ๋‚ด ๋„๊ตฌ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค

import torch
import torch.nn as nn

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # build the smoothed target distribution:
            # smoothing / (classes - 1) everywhere, confidence at the target index
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
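
For example (class count and shapes are illustrative):

criterion = LabelSmoothingLoss(classes=10, smoothing=0.1)
logits = torch.randn(4, 10)            # (batch, n_classes) raw model outputs
target = torch.randint(0, 10, (4,))    # (batch,) integer class indices
loss = criterion(logits, target)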

@mdraw์— ๋™์˜ํ•ฉ๋‹ˆ๋‹ค.
์ข‹์€ ์„ ํƒ์€ ๋‘ ๋‹จ๊ณ„๋กœ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

  1. ๊ธฐ๋Šฅ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ถ€๋“œ๋Ÿฌ์šด ๋ ˆ์ด๋ธ” ์–ป๊ธฐ
import torch

def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method
    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        # uniform smoothing mass on the non-target classes, confidence on the target
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist
  2. Make CrossEntropyLoss support k-hot/smoothed targets.

๊ทธ๋Ÿฌ๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Loss = CrossEntropyLoss(NonSparse=True, ...)
...
data = ...
labels = ...

outputs = model(data)

smooth_label = smooth_one_hot(labels, ...)
loss = Loss(outputs, smooth_label)
...
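
Since the NonSparse argument above is only a proposal, a version runnable today has to compute the soft-target cross entropy by hand. A minimal sketch using the smooth_one_hot helper above (shapes and hyperparameters are illustrative):

import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, soft_targets):
    # cross entropy against a full label distribution instead of class indices
    return -(soft_targets * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

logits = torch.randn(8, 100)             # stand-in for model(data)
labels = torch.randint(0, 100, (8,))     # integer class labels
smooth_label = smooth_one_hot(labels, classes=100, smoothing=0.1)
loss = soft_cross_entropy(logits, smooth_label)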

๊ทธ๋Ÿฐ๋ฐ ImageNet์—์„œ ๋‚ด ๊ตฌํ˜„์„ ํ…Œ์ŠคํŠธํ–ˆ๋Š”๋ฐ ์ข‹์•„ ๋ณด์ž…๋‹ˆ๋‹ค.

|๋ชจ๋ธ | ์‹œ๋Œ€| dtype |๋ฐฐ์น˜ ํฌ๊ธฐ*|gpus | lr | ํŠธ๋ฆญ|top1/top5 |๊ฐœ์„  |
|:----:|:-----:|:-----:|:---------:|:----:|:---:|: ------:|:---------:|:------:|
|resnet50|120 |FP16 |128 | 8 |0.4 | - |77.35/- |๊ธฐ์ค€|
|resnet50|120 |FP16 |128 | 8 |0.4 |๋ ˆ์ด๋ธ” ์Šค๋ฌด๋”ฉ|77.78/93.80| +0.43 |

@zhangguanheng66 I think this is what he said he might take a look at in the future.

๊ทธ๋ƒฅ torch.nn.KLDivLoss๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”. ๊ทธ๊ฒƒ์€ ๋™์ผํ•ฉ๋‹ˆ๋‹ค.


์—…๋ฐ์ดํŠธ: ๋™์ผํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

๋‚˜๋Š” ์ด๊ฒƒ์ด ์ƒˆ๋กœ์šด Snorkel lib๊ฐ€ ๊ตฌํ˜„ํ•œ ๊ฒƒ๊ณผ ์œ ์‚ฌํ•˜๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค.
https://snorkel.readthedocs.io/en/master/packages/_autosummary/classification/snorkel.classification.cross_entropy_with_probs.html

์‚ฌ๋žŒ๋“ค์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•œ ๋ช‡ ๊ฐ€์ง€ ์ถ”๊ฐ€ ์ •๋ณด

Nvidia๊ฐ€ ๋„์›€์ด ๋  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์€ https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5 ๋ฅผ ์ฐธ์กฐ

@suanrong ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค.

====
๊ทธ๋ฆฌ๊ณ  ์•„๋งˆ๋„ ์ด๊ฒƒ์€ ์ด ๋ฌธ์ œ๋ฅผ ์ฝ๋Š” ๋‹ค๋ฅธ ์‚ฌ๋žŒ๋“ค์—๊ฒŒ ๋„์›€์ด ๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

0/1์ด ์•„๋‹Œ ๋ ˆ์ด๋ธ”์— ๋Œ€ํ•œ ๊ต์ฐจ ์—”ํŠธ๋กœํ”ผ๋Š” ๋Œ€์นญ์ด ์•„๋‹ˆ๋ฏ€๋กœ ์„ฑ๋Šฅ ์ €ํ•˜์— ๋Œ€ํ•œ ์„ค๋ช…์ด ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
https://discuss.pytorch.org/t/cross-entropy-for-soft-label/16093/2
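
A quick numeric illustration of that asymmetry (the distributions are arbitrary):

import torch

p = torch.tensor([0.9, 0.05, 0.05])
q = torch.tensor([0.6, 0.2, 0.2])
print(-(p * q.log()).sum())   # H(p, q)
print(-(q * p.log()).sum())   # H(q, p) -- a different value: CE is not symmetric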

์ œ์•ˆ๋œ ๊ตฌํ˜„:

import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothLoss(nn.Module):

    def __init__(self, smoothing=0.0):
        super(LabelSmoothLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        log_prob = F.log_softmax(input, dim=-1)
        # off-target weight: the smoothing mass spread over the other classes
        weight = input.new_ones(input.size()) * \
            self.smoothing / (input.size(-1) - 1.)
        # target weight: the remaining (1 - smoothing) confidence
        weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss

๋‚˜๋Š” ๊ทธ๊ฒƒ์„ ํ™•์ธํ–ˆ๋‹ค :
(1) smoothing=0.0์ผ ๋•Œ 1e-5 ์ •๋ฐ€๋„ ๋‚ด์—์„œ nn.CrossEntropyLoss ์™€ ๋™์ผํ•˜๊ฒŒ ์ถœ๋ ฅ๋ฉ๋‹ˆ๋‹ค.
(2) ์Šค๋ฌด๋”ฉ>0.0์ผ ๋•Œ ์„œ๋กœ ๋‹ค๋ฅธ ํด๋ž˜์Šค weight.sum(dim=-1) ๋Œ€ํ•œ ๊ฐ€์ค‘์น˜์˜ ํ•ฉ์€ ํ•ญ์ƒ 1์ž…๋‹ˆ๋‹ค.

์—ฌ๊ธฐ ๊ตฌํ˜„์—๋Š” ํด๋ž˜์Šค ๊ฐ€์ค‘์น˜ ๊ธฐ๋Šฅ์ด ์—†์Šต๋‹ˆ๋‹ค.
((

๊ทธ๋ƒฅ torch.nn.KLDivLoss๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”. ๊ทธ๊ฒƒ์€ ๋™์ผํ•ฉ๋‹ˆ๋‹ค.

๋” ์ž์„ธํžˆ ์„ค๋ช…ํ•ด ์ฃผ์‹œ๊ฒ ์Šต๋‹ˆ๊นŒ

๊ทธ๋ƒฅ torch.nn.KLDivLoss๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”. ๊ทธ๊ฒƒ์€ ๋™์ผํ•ฉ๋‹ˆ๋‹ค.

๋” ์ž์„ธํžˆ ์„ค๋ช…ํ•ด ์ฃผ์‹œ๊ฒ ์Šต๋‹ˆ๊นŒ

์ด๋ฏธ ์Šค๋ฌด๋”ฉ๋œ ๋ ˆ์ด๋ธ”์ด ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๋ฉด ๋‘ ๊ฐ’์˜ ์ฐจ์ด๊ฐ€ ๋ ˆ์ด๋ธ”์˜ ์—”ํŠธ๋กœํ”ผ์ด๊ณ  ์ƒ์ˆ˜์ด๊ธฐ ๋•Œ๋ฌธ์— torch.nn.KLDivLoss๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

@PistonY ์™œ ์ด๋ ‡๊ฒŒ ๊ฐ„๋‹จํ•˜๊ฒŒ ์‚ฌ์šฉํ•˜์ง€

import torch

def smooth_labels(labels, smoothing_factor, classNum):
    # labels: one-hot encoded targets, shape (batch, classNum)
    with torch.no_grad():
        confidence = 1.0 - smoothing_factor
        true_dist = torch.mul(labels, confidence)
        true_dist = torch.add(true_dist, smoothing_factor / (classNum - 1))
    return true_dist

์—ฌ๊ธฐ ๊ตฌํ˜„์—๋Š” ํด๋ž˜์Šค ๊ฐ€์ค‘์น˜ ๊ธฐ๋Šฅ์ด ์—†์Šต๋‹ˆ๋‹ค.

ํ‰ํ™œ ๋ ˆ์ด๋ธ” ํ…์„œ์— ํด๋ž˜์Šค ๊ฐ€์ค‘์น˜๋ฅผ ๊ณฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๊นŒ?

def smooth_one_hot(true_labels: torch.Tensor, ํด๋ž˜์Šค: int, smoothing=0.0):
""
ํ‰ํ™œํ™” == 0์ด๋ฉด ์›-ํ•ซ ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.
0 < ์Šค๋ฌด๋”ฉ < 1์ด๋ฉด ๋ถ€๋“œ๋Ÿฌ์šด ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.

"""
assert 0 <= smoothing < 1
confidence = 1.0 - smoothing
label_shape = torch.Size((true_labels.size(0), classes))
with torch.no_grad():
    true_dist = torch.empty(size=label_shape, device=true_labels.device)
    true_dist.fill_(smoothing / (classes - 1))
    true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
return true_dist

```

์ด ๊ตฌํ˜„์˜ ๋ฌธ์ œ๋Š” ํด๋ž˜์Šค ์ˆ˜์— ๋งค์šฐ ๋ฏผ๊ฐํ•˜๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

n_classes๊ฐ€ 2์ธ ๊ฒฝ์šฐ 0.5๋ฅผ ์ดˆ๊ณผํ•˜๋Š” ์Šค๋ฌด๋”ฉ์€ ๋ ˆ์ด๋ธ”์„ ๋’ค์ง‘ ์Šต๋‹ˆ๋‹ค. ์ด๋Š” ์‚ฌ์šฉ์ž๊ฐ€ ์›ํ•˜์ง€ ์•Š์„ ๊ฒƒ์ž…๋‹ˆ๋‹ค. n_classes๊ฐ€ 3์ด๋ฉด 2/3์„ ์ดˆ๊ณผํ•˜๋Š” ์Šค๋ฌด๋”ฉ์ด๊ณ  4๊ฐœ ํด๋ž˜์Šค์— ๋Œ€ํ•ด 0.75์ž…๋‹ˆ๋‹ค. ๊ทธ๋ž˜์„œ ์•„๋งˆ๋„:

assert 0 <= smoothing < (classes-1)/classes ์ด ๋ฌธ์ œ๋ฅผ ์žก์„ ์ˆ˜ ์žˆ์ง€๋งŒ ํ‰ํ™œํ™”์— ํด๋ž˜์Šค ์ˆ˜๋ฅผ ๊ณ ๋ คํ•ด์•ผ ํ•œ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๊นŒ?

ํ˜„๋ช…ํ•œ ์ƒ๊ฐ์ž…๋‹ˆ๋‹ค.

ํ† ๋ก ํ•ด์ฃผ์…”์„œ ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค. ๋ถˆ๋ถ„๋ช…ํ•˜๊ณ  ์‹ค์ˆ˜์ฒ˜๋Ÿผ ๋ณด์ด๋Š” ๋ช‡ ๊ฐ€์ง€ ์‚ฌํ•ญ์ด ์žˆ์Šต๋‹ˆ๋‹ค.

  • the weight tensor in @PistonY's implementation,
  • the equivalence between KL divergence and label smoothing ( @suanrong ).

๋ฌด๊ฒŒ ์ •๋ณด:

๋ ˆ์ด๋ธ” ํ‰ํ™œ์šฉ์ง€๋Š” y_k = smoothing / n_classes + (1 - smoothing) * y_{one hot} ์ž…๋‹ˆ๋‹ค. ๊ฐ€์ค‘์น˜์˜ ๊ฐ’์ด๋˜๋„๋ก smoothing / n_classes ๋Œ€์ƒ ์ด์™ธ์˜ ์ง€ํ‘œ์— ๋Œ€ํ•œ, ๊ทธ๋ฆฌ๊ณ  ์ธ smoothing / n_classes + (1 - smoothing) ๋Œ€์ƒ ํด๋ž˜์Šค. ๊ทธ๋Ÿฌ๋‚˜ @PistonY ์˜ ๊ตฌํ˜„์—์„œ torch.scatter_ ํ•จ์ˆ˜๋Š” ๋Œ€์ƒ ๊ฐ’์„ (1 - smoothing) ๋ฎ์–ด์”๋‹ˆ๋‹ค(๊ทธ๋ฆฌ๊ณ  ์ƒ์ˆ˜ ์šฉ์–ด๋Š” ์‚ฌ๋ผ์ง‘๋‹ˆ๋‹ค).
๊ฒŒ๋‹ค๊ฐ€ ๊ณ„์‚ฐ(?)์— n_classes -= 1 ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ์ด์œ ๋ฅผ ์ž˜ ๋ชจ๋ฅด๊ฒ ์Šต๋‹ˆ๋‹ค

KL ๋ฐœ์‚ฐ๊ณผ ๋ ˆ์ด๋ธ” ํ‰ํ™œํ™” ๊ฐ„์˜ ๋™๋“ฑ์„ฑ ์ •๋ณด:

๋ ˆ์ด๋ธ” ํ‰ํ™œํ™” ๊ต์ฐจ ์—”ํŠธ๋กœํ”ผ ์†์‹ค์€ ์œ„์—์„œ ์–ธ๊ธ‰ํ•œ y ๊ฐ€์ค‘์น˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ฝ์Šต๋‹ˆ๋‹ค.

LS(x, y) = - sum_k {y[k] * log-prob(x)[k]}
         = - sum_k {y[k] * log(exp(x[k]) / (sum_j exp(x[j])))}
         = - sum_k {y[k] * (x[k] - log-sum-exp(x))}
         = - sum_k {y[k] * x[k]} + log-sum-exp(x)

์—ฌ๊ธฐ์„œ ์„ธ ๋ฒˆ์งธ์—์„œ ๋„ค ๋ฒˆ์งธ ์ค„์€ sum_k y[k] = smoothing / n_classes * n_classes + (1 - smoothing) = 1 ๋ผ๋Š” ์‚ฌ์‹ค์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

KL ๋ฐœ์‚ฐ ์†์‹ค์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

KL(x, y) = - sum_k {y[k] * x[k] - y[k] * log(y[k])}
         = - sum_k {y[k] * x[k]} + sum_k {y[k] * log(y[k])}
         = - sum_k {y[k] * x[k]} - Const.

๋”ฐ๋ผ์„œ ๊ฒฐ๊ตญ LS(x, y) = KL(x, y) + log-sum-exp(x) + Const. . ์—ฌ๊ธฐ์„œ Const. ๋Š” y ์˜ ์—”ํŠธ๋กœํ”ผ์— ํ•ด๋‹นํ•˜๋Š” ์ƒ์ˆ˜ ํ•ญ์ด๋ฉฐ, ์ด๋Š” ์‹ค์ œ๋กœ ๋‹ค์ค‘ ํด๋ž˜์Šค ์„ค์ •์—์„œ ์ผ์ •ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ log-sum-exp ์šฉ์–ด๋Š” ์–ด๋–ป์Šต๋‹ˆ๊นŒ?

์†Œํ”„ํŠธ ํƒ€๊ฒŸ์„ ํ—ˆ์šฉ ํ•˜๋Š” KLDiv loss + log-sum-exp ์™€ ๋™์ผํ•˜๋‹ค๋Š” ๊ฒƒ์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. y . ์ด ์šฉ์–ด๋ฅผ ์‚ญ์ œํ•˜๋Š” ๊ฒƒ์ด ํ•ฉ๋ฆฌ์ ์ด๊ฒŒ ํ•˜๋Š” ๋กœ์ง“์— ๋Œ€ํ•œ ๊ฐ€์ •์ด ์žˆ์Šต๋‹ˆ๊นŒ?

๋งŽ์€ ์„ค๋ช… ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค.
๊ฑด๋ฐฐ !

@antrec ๊ฐ์‚ฌ

๋‹น์‹ ์ด ๋งž์Šต๋‹ˆ๋‹ค. logsoftmax ํ•จ์ˆ˜๋ฅผ ๋ฌด์‹œํ•˜๊ณ  ์‹ค์ˆ˜๋ฅผ ํ–ˆ์Šต๋‹ˆ๋‹ค.

๋ ˆ์ด๋ธ” ์Šค๋ฌด๋”ฉ ๊ต์ฐจ ์—”ํŠธ๋กœํ”ผ ์†์‹ค ํ•จ์ˆ˜์˜ ๊ตฌํ˜„:

import torch
import torch.nn as nn
import torch.nn.functional as F

def linear_combination(x, y, epsilon):
    return epsilon * x + (1 - epsilon) * y

def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, epsilon: float = 0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, preds, target):
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        # uniform part: average negative log-probability over all classes
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        # one-hot part: standard negative log-likelihood of the target
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return linear_combination(loss / n, nll, self.epsilon)
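
For example (shapes are illustrative):

criterion = LabelSmoothingCrossEntropy(epsilon=0.1)
preds = torch.randn(8, 10)             # (batch, n_classes) logits
target = torch.randint(0, 10, (8,))    # integer class labels
loss = criterion(preds, target)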

ํ™œ๋™์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•˜์ดํ”„๋ฆฌ๋กœ ์ด๋™

์ด ํŽ˜์ด์ง€๊ฐ€ ๋„์›€์ด ๋˜์—ˆ๋‚˜์š”?
0 / 5 - 0 ๋“ฑ๊ธ‰