[PyTorch][Feature Request] Label Smoothing for CrossEntropyLoss

Created on 10 May 2018 · 22 Comments · Source: pytorch/pytorch

Hi, guys. The torch.LongTensor type of the target hinders implementations like some of the methods in the reference. So would it be possible to add an argument label_smoothing to torch.nn.CrossEntropyLoss(), or maybe simply add docs showing how to convert the target into a one-hot vector to work with torch.nn.CrossEntropyLoss(), or any other simple way? Thanks.

cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @mruberry

enhancement high priority loss nn triage review triaged

Most helpful comment

Here is my implementation

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

All 22 comments

@KaiyuYue
For label_smoothing, you can look at the implementation of NJUNMT-pytorch.

In the class NMTCritierion

https://github.com/whr94621/NJUNMT-pytorch/blob/aff968c0da9273dc42eabbb8ac4e459f9195f6e4/src/modules/criterions.py#L131

See https://discuss.pytorch.org/t/cross-entropy-with-one-hot-targets/13580/5. The cross_entropy() function that's shown there should work with smoothed labels that have the same dimension as the network outputs.
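
For reference, a minimal soft-target cross-entropy in that spirit might look like this (a sketch; the function name and signature are illustrative, not the exact code from that thread):

import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, soft_targets):
    # soft_targets has the same shape as logits; each row is a probability distribution
    log_probs = F.log_softmax(logits, dim=-1)
    return torch.mean(torch.sum(-soft_targets * log_probs, dim=-1))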

I don't think CrossEntropyLoss() should directly support a label_smoothing option, since label smoothing can be done in many different ways and the smoothing itself can be easily done manually by the user. But I agree it should at least be mentioned in the docs how to deal with targets that can't be represented by scalar values, or add support for passing (k-hot/smoothed) targets to CrossEntropyLoss.

Maybe we need something like NonSparseCrossEntropy? (well, it's hard to name it)

Here is my implementation

import torch
from torch import nn

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # build the smoothed target distribution without tracking gradients
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
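
A hypothetical usage sketch (batch size and class count here are made up for illustration):

# assumes the imports and the LabelSmoothingLoss class above
criterion = LabelSmoothingLoss(classes=10, smoothing=0.1)
logits = torch.randn(4, 10)              # (batch, classes) raw scores from a model
targets = torch.randint(0, 10, (4,))     # (batch,) class indices
loss = criterion(logits, targets)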

I agree with @mdraw.
A good choice is to do it in two steps:

  1. Use a function to get the smoothed labels:
def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method

    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist
  2. Make CrossEntropyLoss support k-hot/smoothed targets.

Then we can use it like

Loss = CrossEntropyLoss(NonSparse=True, ...)
. . .
data = ...
labels = ...

outputs = model(data)

smooth_label = smooth_one_hot(labels, ...)
loss = Loss(outputs, smooth_label)
...

By the way, I tested my implementation on ImageNet, and it looks good:

|model | epochs| dtype |batch size*|gpus | lr | tricks|top1/top5 |improve |
|:----:|:-----:|:-----:|:---------:|:----:|:---:|:------:|:---------:|:------:|
|resnet50|120 |FP16 |128 | 8 |0.4 | - |77.35/- |baseline|
|resnet50|120 |FP16 |128 | 8 |0.4 |Label smoothing|77.78/93.80| +0.43 |

I believe @zhangguanheng66 said that this is something he might be able to look at in the future.

Just use torch.nn.KLDivLoss. It is the same.


Update: it is not the same.

I believe this is similar to what the new Snorkel lib implemented:
https://snorkel.readthedocs.io/en/master/packages/_autosummary/classification/snorkel.classification.cross_entropy_with_probs.html

Just some extra info on how people are working around the issue.

@suanrong Thanks a lot.

====
And maybe this is helpful for others who read this issue

Note that cross-entropy for non 0/1 labels is not symmetric, which could be an explanation for the poor performance.
https://discuss.pytorch.org/t/cross-entropy-for-soft-label/16093/2
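
A tiny numerical illustration of the asymmetry (not from the thread):

import torch

def cross_entropy(a, b):
    # H(a, b) = - sum_k a[k] * log(b[k])
    return -(a * b.log()).sum()

p = torch.tensor([0.9, 0.1])
q = torch.tensor([0.6, 0.4])
print(cross_entropy(p, q), cross_entropy(q, p))   # ~0.551 vs ~0.984: H(p, q) != H(q, p)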

Suggested implementation:

import torch.nn.functional as F
from torch import nn

class LabelSmoothLoss(nn.Module):

    def __init__(self, smoothing=0.0):
        super(LabelSmoothLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        log_prob = F.log_softmax(input, dim=-1)
        weight = input.new_ones(input.size()) * \
            self.smoothing / (input.size(-1) - 1.)
        weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss

I have checked that:
(1) When smoothing=0.0, the output is the same as nn.CrossEntropyLoss within precision 1e-5.
(2) When smoothing>0.0, the sums of weights over different classes weight.sum(dim=-1) are always 1.
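
A sketch of such a check (illustrative; it assumes the LabelSmoothLoss class above is defined in the same scope):

import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))

# (1) with smoothing=0.0 the result should match nn.CrossEntropyLoss
ce = nn.CrossEntropyLoss()(logits, targets)
ls = LabelSmoothLoss(smoothing=0.0)(logits, targets)
print(torch.allclose(ce, ls, atol=1e-5))                    # expected: True

# (2) with smoothing>0.0 the per-sample weights should still sum to 1
smoothing = 0.1
weight = logits.new_ones(logits.size()) * smoothing / (logits.size(-1) - 1.)
weight.scatter_(-1, targets.unsqueeze(-1), 1. - smoothing)
print(torch.allclose(weight.sum(dim=-1), torch.ones(8)))    # expected: True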

The implementations here lack the class-weights feature.

Just use torch.nn.KLDivLoss. It is the same.

Can you please elaborate more?

Assuming you already have the smoothed labels, you can just use torch.nn.KLDivLoss, since the difference between the two losses is the entropy of the labels, which is a constant.
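
For example, a minimal sketch of that route (it assumes the smooth_one_hot helper from the earlier comment is in scope, and relies on nn.KLDivLoss expecting log-probabilities as its input):

import torch
import torch.nn.functional as F
from torch import nn

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
smoothed = smooth_one_hot(labels, classes=10, smoothing=0.1)

loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(logits, dim=-1), smoothed)
# this differs from the label-smoothing cross-entropy only by the (constant) entropy of `smoothed`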

@PistonY why not do it this way, which is much simpler:

with torch.no_grad():
    confidence = 1.0 - smoothing_factor
    true_dist = torch.mul(labels, confidence)
    true_dist = torch.add(true_dist, smoothing_factor / (classNum - 1))
    print(true_dist)
return true_dist
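
Note this assumes labels is already a one-hot (or k-hot) float tensor. A self-contained sketch of the same idea (the helper name is made up):

import torch

def smooth_from_one_hot(one_hot_labels: torch.Tensor, smoothing: float = 0.1):
    # hypothetical helper; assumes `one_hot_labels` is already a float one-hot
    # (or k-hot) tensor of shape (batch, classes)
    num_classes = one_hot_labels.size(-1)
    with torch.no_grad():
        smoothed = one_hot_labels * (1.0 - smoothing) + smoothing / (num_classes - 1)
    # note: the target entry becomes (1 - smoothing) + smoothing / (classes - 1), slightly
    # larger than in the scatter_-based versions above, which set it to exactly (1 - smoothing)
    return smoothed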

The implementations here lack the class-weights feature.

Can I multiply the class weights with the smoothed label tensor?
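
One possible way (a sketch, not something settled in this thread) is to weight each sample by the weight of its original hard class, mirroring the weighted-mean reduction of nn.CrossEntropyLoss(weight=...):

import torch
import torch.nn.functional as F

def weighted_soft_cross_entropy(logits, smoothed_targets, hard_targets, class_weights):
    # hypothetical helper: class_weights is a 1-D tensor of length n_classes and
    # hard_targets are the original class indices, used to pick each sample's weight
    log_probs = F.log_softmax(logits, dim=-1)
    per_sample = torch.sum(-smoothed_targets * log_probs, dim=-1)
    w = class_weights[hard_targets]
    # mirror the weighted-mean reduction of nn.CrossEntropyLoss(weight=...)
    return (w * per_sample).sum() / w.sum()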

def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method
    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist

The problem with this implementation is that it's very sensitive to the number of classes.

When n_classes is 2, any smoothing above 0.5 will reverse the labels, which I'm sure the person does not want; when n_classes is 3 it's any smoothing above 2/3, and 0.75 for 4 classes. So maybe

assert 0 <= smoothing < (classes - 1) / classes

would catch this issue, but I feel the smoothing needs to take the number of classes into account?
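
A minimal sketch of that guard, wrapping the helper above (the wrapper name is made up):

def smooth_one_hot_checked(true_labels, classes, smoothing=0.0):
    # hypothetical wrapper; assumes the smooth_one_hot helper above is in scope.
    # beyond (classes - 1) / classes the off-target entries would exceed the target entry
    assert 0 <= smoothing < (classes - 1) / classes
    return smooth_one_hot(true_labels, classes, smoothing)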

It's a wise idea, I think.

Thanks for the discussion. There are a few points that remain unclear and look like mistakes to me:

  • the weight tensor in @PistonY's implementation
  • the equivalence between KL divergence and label smoothing (@suanrong)

About the weights:

The label smoothing paper states y_k = smoothing / n_classes + (1 - smoothing) * y_{one hot}. So the value of the weight is smoothing / n_classes for indices other than the target, and smoothing / n_classes + (1 - smoothing) for the target class. Yet in @PistonY's implementation, the function torch.scatter_ overwrites the value for the target with (1 - smoothing) (and the constant term disappears).
Moreover, I do not really understand why we divide by n_classes - 1 in the computation (?)
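
To make the difference concrete, here is a small illustrative computation (not from the thread) with smoothing = 0.1 and n_classes = 5:

import torch

smoothing, n_classes, target = 0.1, 5, 2

# paper-style weights: uniform smoothing / n_classes everywhere, added to the scaled one-hot
paper = torch.full((n_classes,), smoothing / n_classes)
paper[target] += 1.0 - smoothing              # off-target 0.02, target 0.92

# scatter_-style weights (as in the implementations above): smoothing / (n_classes - 1)
# everywhere, then the target entry is overwritten with exactly (1 - smoothing)
scatter = torch.full((n_classes,), smoothing / (n_classes - 1))
scatter[target] = 1.0 - smoothing             # off-target 0.025, target 0.90

print(paper.sum(), scatter.sum())             # both sum to 1; dividing by (n_classes - 1)
                                              # is what keeps the row normalized after the overwrite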

About the equivalence between KL divergence and label-smoothing:

The label-smoothing cross-entropy loss reads, with y the weights mentioned above,

LS(x, y) = - sum_k {y[k] * log-prob(x)[k]}
         = - sum_k {y[k] * log(exp(x[k]) / (sum_j exp(x[j])))}
         = - sum_k {y[k] * (x[k] - log-sum-exp(x))}
         = - sum_k {y[k] * x[k]} + log-sum-exp(x)

where the third to the fourth line uses the fact that sum_k y[k] = smoothing / n_classes * n_classes + (1 - smoothing) = 1.

The KL-divergence loss reads,

KL(x, y) = - sum_k {y[k] * x[k] - y[k] * log(y[k])}
         = - sum_k {y[k] * x[k]} + sum_k {y[k] * log(y[k])}
         = - sum_k {y[k] * x[k]} - Const.

So in the end we have LS(x, y) = KL(x, y) + log-sum-exp(x) + Const., where Const. is the constant term corresponding to the entropy of y, which is indeed constant in multiclass settings. But what about the log-sum-exp term?

I did a few computations using a custom cross-entropy function accepting soft targets, and they show that it is indeed equal to the KLDiv loss plus log-sum-exp, up to the constant term corresponding to the entropy of y. Is there any assumption on the logits that makes it reasonable to drop this term?

Thanks a lot for the clarifications.
Cheers !

Thanks @antrec !

You are right. I ignored the logsoftmax function and made a mistake.
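
In other words, once the input to nn.KLDivLoss is log_softmax(x) rather than the raw logits x, the log-sum-exp term disappears and only the constant entropy term remains. A quick numerical sketch of this (illustrative, not from the thread):

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 5)
y = torch.full((4, 5), 0.1 / 4)                               # smoothed targets,
y.scatter_(1, torch.randint(0, 5, (4, 1)), 0.9)               # each row sums to 1

ls = torch.mean(torch.sum(-y * F.log_softmax(logits, dim=-1), dim=-1))
kl = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(logits, dim=-1), y)
entropy = -(y * y.log()).sum(dim=-1).mean()                   # the constant term H(y)
print(torch.allclose(ls, kl + entropy, atol=1e-5))            # expected: True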

The implementation of a label smoothing cross-entropy loss function:

import torch
import torch.nn.functional as F
from torch import nn

def linear_combination(x, y, epsilon):
    return epsilon * x + (1 - epsilon) * y

def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, epsilon:float=0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, preds, target):
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return linear_combination(loss/n, nll, self.epsilon)
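
Usage is then the same as with nn.CrossEntropyLoss, e.g. (an illustrative sketch):

criterion = LabelSmoothingCrossEntropy(epsilon=0.1)
logits = torch.randn(4, 10, requires_grad=True)   # (batch, classes)
targets = torch.randint(0, 10, (4,))              # class indices, as for nn.CrossEntropyLoss
loss = criterion(logits, targets)
loss.backward()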

Bumping to hi-pri based on activity
