Numpy: argsort donne de mauvais résultats

Créé le 8 mars 2017  ·  4Commentaires  ·  Source: numpy/numpy

La fonction argsort semble être cassée. En regardant le code fourni, l'argsort pour les lignes [0, 1] est correct mais il est foiré pour les lignes [2, 3].
J'ai testé cela sur différentes installations NumPy et versions 1.11.0 et 1.12.0

import numpy as np

vec = np.array([
    [-1.4, -1.2,  1.3],
    [-3.6,  3.9, -3.7],
    [-2.3,  1.5, -2. ],
    [-2.6,  2.4, -1.6]
    ])

In [1]: np.argsort(-vec, axis=1)
Out[1]: 
array([[2, 1, 0],
       [1, 0, 2],
       [1, 2, 0],
       [1, 2, 0]])
53 - Invalid

Commentaire le plus utile

Ceci est le deuxième résultat Google pour "np argsort erroné".

L'explication sur la page docs n'est pas claire (pour moi). Je vais ajouter ma propre explication ici dans l'espoir que cela aide quelqu'un:

x = numpy.array([1.48,1.31,0.0,0.8])
print x.argsort()

>[2 3 1 0]

Certaines personnes pourraient s'attendre à ce que cela donne à la place [3, 2, 0, 1] , c'est-à-dire que le 0e élément du tableau non trié devrait être le 3e élément du tableau trié.

En fait, il fournit des indices tels que x[np.argsort(x)] vous donnera une liste triée, c'est-à-dire [0.0, 0.8, 1.31, 1.48] . Autrement dit, [2 3 1 0] vous indique que le 0e élément du tableau trié est le 2e élément du tableau non trié.

Si vous voulez vraiment obtenir [3, 2, 0, 1] en sortie, vous pouvez à la place faire

np.argsort(np.argsort(x))
>[3 2 0 1]

Alternativement, si vous êtes vraiment hors de propos comme moi et que vous ne voulez, disons, que les indices des 3 plus grands éléments de x :

np.argsort(x)[:-4:-1]
>[0, 1, 3]

Tous les 4 commentaires

Je ne vois rien de mal avec le résultat, imprimez vec[np.arange(4)[:, np.newaxis], np.argsort(-vec, axis=1)] et voyez que ça a l'air bien.

Ceci est le deuxième résultat Google pour "np argsort erroné".

L'explication sur la page docs n'est pas claire (pour moi). Je vais ajouter ma propre explication ici dans l'espoir que cela aide quelqu'un:

x = numpy.array([1.48,1.31,0.0,0.8])
print x.argsort()

>[2 3 1 0]

Certaines personnes pourraient s'attendre à ce que cela donne à la place [3, 2, 0, 1] , c'est-à-dire que le 0e élément du tableau non trié devrait être le 3e élément du tableau trié.

En fait, il fournit des indices tels que x[np.argsort(x)] vous donnera une liste triée, c'est-à-dire [0.0, 0.8, 1.31, 1.48] . Autrement dit, [2 3 1 0] vous indique que le 0e élément du tableau trié est le 2e élément du tableau non trié.

Si vous voulez vraiment obtenir [3, 2, 0, 1] en sortie, vous pouvez à la place faire

np.argsort(np.argsort(x))
>[3 2 0 1]

Alternativement, si vous êtes vraiment hors de propos comme moi et que vous ne voulez, disons, que les indices des 3 plus grands éléments de x :

np.argsort(x)[:-4:-1]
>[0, 1, 3]

Si vous voulez vraiment obtenir [3, 2, 0, 1] en sortie, vous pouvez à la place faire

Cela sera plus rapide :

a = np.empty(len(x), np.intp)
a[np.argsort(x)] = np.arange(len(x))

9880 suggère d'ajouter ceci à numpy comme np.invert_permutation(np.argsort(x))

@rossbar , @bjnath : Peut-être vaut-il la peine d'extraire des éléments du commentaire de @ghost ci-dessus et de les mettre dans la documentation ? J'ai ajouté quelques liens croisés supplémentaires pour montrer plus d'exemples de confusion.

Cette page vous a été utile?
0 / 5 - 0 notes