Numpy: argsort memberikan hasil yang salah

Dibuat pada 8 Mar 2017  ·  4Komentar  ·  Sumber: numpy/numpy

Fungsi argsort tampaknya rusak. Melihat kode yang diberikan, argsort untuk baris [0, 1] benar tetapi kacau untuk baris [2, 3].
Saya menguji ini pada instalasi NumPy yang berbeda dan versi 1.11.0 dan 1.12.0

import numpy as np

vec = np.array([
    [-1.4, -1.2,  1.3],
    [-3.6,  3.9, -3.7],
    [-2.3,  1.5, -2. ],
    [-2.6,  2.4, -1.6]
    ])

In [1]: np.argsort(-vec, axis=1)
Out[1]: 
array([[2, 1, 0],
       [1, 0, 2],
       [1, 2, 0],
       [1, 2, 0]])
53 - Invalid

Komentar yang paling membantu

Ini adalah hasil google kedua untuk "np argsort salah".

Penjelasan di halaman dokumen tidak jelas (bagi saya). Saya akan menambahkan penjelasan saya sendiri di sini dengan harapan dapat membantu seseorang:

x = numpy.array([1.48,1.31,0.0,0.8])
print x.argsort()

>[2 3 1 0]

Beberapa orang mungkin mengharapkan ini untuk memberikan [3, 2, 0, 1] , yaitu, elemen ke-0 dalam array yang tidak disortir harus menjadi elemen ke-3 dalam array yang diurutkan.

Apa yang sebenarnya dilakukannya adalah memberikan indeks sedemikian rupa sehingga x[np.argsort(x)] akan memberi Anda daftar yang diurutkan, yaitu [0.0, 0.8, 1.31, 1.48] . Dengan kata lain, [2 3 1 0] memberitahu Anda bahwa elemen ke-0 dari array yang diurutkan adalah elemen ke-2 dari array yang tidak disortir.

Jika Anda benar-benar ingin mendapatkan [3, 2, 0, 1] sebagai output, Anda dapat melakukannya

np.argsort(np.argsort(x))
>[3 2 0 1]

Atau, jika Anda benar-benar berada di luar basis seperti saya dulu dan hanya ingin, katakanlah, indeks dari 3 elemen terbesar di x :

np.argsort(x)[:-4:-1]
>[0, 1, 3]

Semua 4 komentar

Tidak ada yang salah dengan hasilnya, cetak vec[np.arange(4)[:, np.newaxis], np.argsort(-vec, axis=1)] dan lihat hasilnya bagus.

Ini adalah hasil google kedua untuk "np argsort salah".

Penjelasan di halaman dokumen tidak jelas (bagi saya). Saya akan menambahkan penjelasan saya sendiri di sini dengan harapan dapat membantu seseorang:

x = numpy.array([1.48,1.31,0.0,0.8])
print x.argsort()

>[2 3 1 0]

Beberapa orang mungkin mengharapkan ini untuk memberikan [3, 2, 0, 1] , yaitu, elemen ke-0 dalam array yang tidak disortir harus menjadi elemen ke-3 dalam array yang diurutkan.

Apa yang sebenarnya dilakukannya adalah memberikan indeks sedemikian rupa sehingga x[np.argsort(x)] akan memberi Anda daftar yang diurutkan, yaitu [0.0, 0.8, 1.31, 1.48] . Dengan kata lain, [2 3 1 0] memberitahu Anda bahwa elemen ke-0 dari array yang diurutkan adalah elemen ke-2 dari array yang tidak disortir.

Jika Anda benar-benar ingin mendapatkan [3, 2, 0, 1] sebagai output, Anda dapat melakukannya

np.argsort(np.argsort(x))
>[3 2 0 1]

Atau, jika Anda benar-benar berada di luar basis seperti saya dulu dan hanya ingin, katakanlah, indeks dari 3 elemen terbesar di x :

np.argsort(x)[:-4:-1]
>[0, 1, 3]

Jika Anda benar-benar ingin mendapatkan [3, 2, 0, 1] sebagai output, Anda dapat melakukannya

Melakukan ini akan lebih cepat:

a = np.empty(len(x), np.intp)
a[np.argsort(x)] = np.arange(len(x))

9880 menyarankan menambahkan ini ke numpy sebagai np.invert_permutation(np.argsort(x))

@rossbar , @bjnath : Mungkin perlu mengekstrak hal-hal dari komentar @ghost di atas dan memasukkannya ke dalam dokumen? Saya telah menambahkan beberapa tautan silang lagi untuk menunjukkan lebih banyak contoh kebingungan.

Apakah halaman ini membantu?
0 / 5 - 0 peringkat