Pytorch: [PyTorch][Feature Request] Lissage des étiquettes pour CrossEntropyLoss

Créé le 10 mai 2018  ·  22Commentaires  ·  Source: pytorch/pytorch

Salut les gars. Le type torch.LongTensor de target gênera l'implémentation comme certaines méthodes de reference . Alors, est-il possible d'ajouter un Arg: label_smoothing pour torch.nn.CrossEntropyLoss() , ou peut-être simplement d'ajouter la documentation pour montrer comment convertir le target en one-hot vector avec lequel travailler torch.nn.CrossEntropyLoss() ensemble, ou tout autre moyen simple ? Merci.

cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @mruberry

enhancement high priority loss nn triage review triaged

Commentaire le plus utile

voici mon outil

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

Tous les 22 commentaires

@KaiyuYue
Pour label_smoothing, regardez l'implémentation de NJUNMT-pytorch

Dans la classe NMTCritierion

https://github.com/whr94621/NJUNMT-pytorch/blob/aff968c0da9273dc42eabbb8ac4e459f9195f6e4/src/modules/criterions.py#L131

Voir https://discuss.pytorch.org/t/cross-entropy-with-one-hot-targets/13580/5. La fonction cross_entropy() qui y est montrée devrait fonctionner avec des étiquettes lissées qui ont la même dimension que les sorties du réseau.

Je ne pense pas que CrossEntropyLoss() devrait prendre en charge directement une option label_smoothing , car le lissage des étiquettes peut être effectué de différentes manières et le lissage lui-même peut être facilement effectué manuellement par l'utilisateur. Mais je suis d'accord qu'il devrait au moins être mentionné dans la documentation comment traiter les cibles qui ne peuvent pas être représentées par des valeurs scalaires, ou ajouter la prise en charge du passage des cibles (k-hot/lissé) à CrossEntropyLoss .

Peut-être qu'on a besoin de qc comme NonSparseCrossEntropy ? (enfin.. c'est difficile de le nommer)

voici mon outil

class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

Je suis d'accord avec @mdraw
Un bon choix est de le faire en deux étapes :

  1. Utiliser une fonction pour obtenir une étiquette lisse
def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """
    if smoothing == 0, it's one-hot method
    if 0 < smoothing < 1, it's smooth method

    """
    assert 0 <= smoothing < 1
    confidence = 1.0 - smoothing
    label_shape = torch.Size((true_labels.size(0), classes))
    with torch.no_grad():
        true_dist = torch.empty(size=label_shape, device=true_labels.device)
        true_dist.fill_(smoothing / (classes - 1))
        true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
    return true_dist
  1. Faites en sorte que CrossEntropyLoss prenne en charge les cibles k-hot/lissées.

Ensuite, nous pouvons l'utiliser comme

Loss = CrossEntropyLoss(NonSparse=True, ...)
. . .
data = ...
labels = ...

outputs = model(data)

smooth_label = smooth_one_hot(labels, ...)
loss = (outputs, smooth_label)
...

Au fait, j'ai testé mon outil sur ImageNet, ça a l'air bien

|modèle | époques| dtype |taille du lot*|gpus | gd | astuces|top1/top5 |améliorer |
|:----:|:-----:|:-----:|:---------:|:----:|:---:|: ------:|:---------:|:------:|
|resnet50|120 |FP16 |128 | 8 |0.4 | - |77,35/- |base|
|resnet50|120 |FP16 |128 | 8 |0.4 |Lissage des étiquettes|77.78/93.80| +0.43 |

Je crois que @ zhangguanheng66 a dit que c'était quelque chose qu'il pourrait envisager à l'avenir.

Utilisez simplement torch.nn.KLDivLoss. C'est le même.


Mise à jour : ce n'est pas la même chose.

Je pense que cela est similaire à ce que la nouvelle bibliothèque Snorkel a implémenté :
https://snorkel.readthedocs.io/en/master/packages/_autosummary/classification/snorkel.classification.cross_entropy_with_probs.html

Juste quelques informations supplémentaires sur la façon dont les gens contournent le problème

voir https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5 pour savoir comment Nvidia le fait qui pourrait aider ?

@suanrong Merci beaucoup.

====
Et peut-être que cela sera utile pour ceux qui liront ce numéro

Notez que l'entropie croisée pour les étiquettes non 0/1 n'est pas symétrique, ce qui pourrait expliquer les mauvaises performances.
https://discuss.pytorch.org/t/cross-entropy-for-soft-label/16093/2

Mise en œuvre suggérée :

class LabelSmoothLoss(nn.Module):

    def __init__(self, smoothing=0.0):
        super(LabelSmoothLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        log_prob = F.log_softmax(input, dim=-1)
        weight = input.new_ones(input.size()) * \
            self.smoothing / (input.size(-1) - 1.)
        weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss

J'ai vérifié que :
(1) Lorsque lissage=0.0, la sortie est la même que nn.CrossEntropyLoss avec une précision de 1e-5 .
(2) Lors du lissage>0,0, les sommes des poids sur différentes classes weight.sum(dim=-1) sont toujours 1.

Les implémentations ici manquent de fonctionnalité de pondération des classes.
((

Utilisez simplement torch.nn.KLDivLoss. C'est le même.

pouvez-vous s'il vous plaît élaborer plus

Utilisez simplement torch.nn.KLDivLoss. C'est le même.

pouvez-vous s'il vous plaît élaborer plus

En supposant que vous ayez déjà une étiquette lissée, vous pouvez simplement utiliser torch.nn.KLDivLoss car la différence entre elles est l'entropie de l'étiquette et est une constante.

@PistonY pourquoi ne pas utiliser cette manière très simple :

with torch.no_grad():
    confidence = 1.0 - smoothing_factor
    true_dist = torch.mul(labels, confidence)
    true_dist = torch.add(true_dist, smoothing_factor / (classNum - 1))
    print(true_dist)
return true_dist

Les implémentations ici manquent de fonctionnalité de pondération des classes.

Puis-je multiplier les poids de classe sur le tenseur d'étiquette lissé ?

def smooth_one_hot(true_labels : torch.Tensor, classes : int, smoothing=0.0) :
"""
si lissage == 0, c'est la méthode one-hot
si 0 < lissage < 1, c'est la méthode de lissage

"""
assert 0 <= smoothing < 1
confidence = 1.0 - smoothing
label_shape = torch.Size((true_labels.size(0), classes))
with torch.no_grad():
    true_dist = torch.empty(size=label_shape, device=true_labels.device)
    true_dist.fill_(smoothing / (classes - 1))
    true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
return true_dist

```

Le problème avec cette implémentation est qu'elle est très sensible au nombre de classes

Où n_classes est égal à 2, tout lissage au-dessus de 0,5 inversera les étiquettes, ce que la personne ne souhaite certainement pas ; lorsque n_classes est égal à 3, il s'agit de tout lissage au-dessus de 2/3 et de 0,75 pour 4 classes. Alors peut-être:

assert 0 <= smoothing < (classes-1)/classes permettrait de résoudre ce problème, mais je pense que le lissage doit prendre en compte le nombre de classes ?

def smooth_one_hot(true_labels : torch.Tensor, classes : int, smoothing=0.0) :

"""
if smoothing == 0, it's one-hot method
if 0 < smoothing < 1, it's smooth method

"""
assert 0 <= smoothing < 1
confidence = 1.0 - smoothing
label_shape = torch.Size((true_labels.size(0), classes))
with torch.no_grad():
    true_dist = torch.empty(size=label_shape, device=true_labels.device)
    true_dist.fill_(smoothing / (classes - 1))
    true_dist.scatter_(1, true_labels.data.unsqueeze(1), confidence)
return true_dist

```

Le problème avec cette implémentation est qu'elle est très sensible au nombre de classes

Où n_classes est égal à 2, tout lissage au-dessus de 0,5 inversera les étiquettes, ce que la personne ne souhaite certainement pas ; lorsque n_classes est égal à 3, il s'agit de tout lissage au-dessus de 2/3 et de 0,75 pour 4 classes. Alors peut-être:

assert 0 <= smoothing < (classes-1)/classes permettrait de résoudre ce problème, mais je pense que le lissage doit prendre en compte le nombre de classes ?

C'est une sage idée je pense.

Merci pour le débat. Il y a quelques points qui restent flous et me semblent être des erreurs :

  • le tenseur de poids dans l' implémentation de
  • l'équivalence entre divergence KL et lissage d'étiquette ( @suanrong )

A propos des poids :

L'étiquette du papier de lissage indique y_k = smoothing / n_classes + (1 - smoothing) * y_{one hot} . Ainsi, la valeur du poids est de smoothing / n_classes pour les indices autres que la cible, et elle est de smoothing / n_classes + (1 - smoothing) pour la classe cible. Pourtant, dans l' implémentation de torch.scatter_ remplace la valeur de la cible par (1 - smoothing) (et le terme constant disparaît).
De plus, je ne comprends pas vraiment pourquoi on utilise n_classes -= 1 dans le calcul (?)

À propos de l'équivalence entre la divergence KL et le lissage des étiquettes :

La perte d'entropie croisée de lissage des étiquettes se lit, avec y les poids mentionnés ci-dessus,

LS(x, y) = - sum_k {y[k] * log-prob(x)}
         = - sum_k {y[k] * log(exp(x[k]) / (sum_j exp(x[j])))}
         = - sum_k {y[k] * (x[k] - log-sum-exp(x))}
         = - sum_k {y[k] * x[k]} + log-sum-exp(x)

où la troisième à la quatrième ligne utilise le fait que sum_k y[k] = smoothing / n_classes * n_classes + (1 - smoothing) = 1 .

La perte de divergence KL lit,

KL(x, y) = - sum_k {y[k] * x[k] - y[k] * log(y[k])
         = - sum_k {y[k] * x[k]} - sum_k {y[k] * log(y[k])}
         = - sum_k {y[k] * x[k]} - Const.

Nous avons donc au final LS(x, y) = KL(x, y) + log-sum-exp(x) + Const. , où Const. est le terme constant correspondant à l'entropie de y , qui est en effet constante dans les paramètres multiclasses. Mais qu'en est-il du terme log-sum-exp ?

J'ai fait quelques calculs en utilisant une fonction d'entropie croisée personnalisée acceptant les cibles souples , et cela montre qu'elle est bien égale à la perte de KLDiv plus log-sum-exp , jusqu'au terme constant correspondant à l'entropie de y . Y a-t-il une hypothèse sur les logits qui rend raisonnable l'abandon de ce terme ?

Merci beaucoup pour les éclaircissements.
À votre santé !

Merci @antrec !

Vous avez raison. J'ai ignoré la fonction logsoftmax et j'ai fait une erreur.

La mise en œuvre d'une fonction de perte d'entropie croisée par lissage d'étiquettes :

import torch.nn.functional as F
def linear_combination(x, y, epsilon): 
    return epsilon*x + (1-epsilon)*y

def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, epsilon:float=0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, preds, target):
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return linear_combination(loss/n, nll, self.epsilon)

Passage à hi-pri en fonction de l'activité

Cette page vous a été utile?
0 / 5 - 0 notes