Numpy: 多维布尔掩码的广播错误

创建于 2019-04-03  ·  3评论  ·  资料来源: numpy/numpy

尝试使用两个一维布尔掩码(形状 N 和 M)对形状为 [N, M] 的二维数组进行索引时,某些 True 和 False 组合会导致广播错误(尤其是当其中一个全部为假时)。 我不确定这种行为是否在意料之中,但这似乎非常令人惊讶和不受欢迎。

在下面的示例中, x[[False, True, True], [True, True, True]]错误,而x[[False, True, True], True]x[[False, True, True]]具有预期的行为。

重现代码示例:

import numpy as np
from itertools import product

x = np.zeros((3,3))
mask_1d = [*product([True, False], repeat=3)]

for row_mask, col_mask in product(mask_1d, mask_1d):
    try:
        x[row_mask, col_mask]
    except IndexError as e:
        print(row_mask, col_mask)
        print(e)

错误信息:

     (True, True, True) (True, True, False)
shape mismatch: indexing arrays could not be broadcast together with shapes (3,) (2,) 
(True, True, True) (True, False, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (3,) (2,) 
(True, True, True) (False, True, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (3,) (2,) 
(True, True, True) (False, False, False)
shape mismatch: indexing arrays could not be broadcast together with shapes (3,) (0,) 
(True, True, False) (True, True, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (3,) 
(True, True, False) (False, False, False)
shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (0,) 
(True, False, True) (True, True, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (3,) 
(True, False, True) (False, False, False)
shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (0,) 
(False, True, True) (True, True, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (3,) 
(False, True, True) (False, False, False)
shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (0,) 
(False, False, False) (True, True, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (0,) (3,) 
(False, False, False) (True, True, False)
shape mismatch: indexing arrays could not be broadcast together with shapes (0,) (2,) 
(False, False, False) (True, False, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (0,) (2,) 
(False, False, False) (False, True, True)
shape mismatch: indexing arrays could not be broadcast together with shapes (0,) (2,)

Numpy/Python 版本信息:

1.16.2 3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34) 
[GCC 7.3.0]

最有用的评论

对于布尔数组,代码假定您正在尝试同时索引单个维度所有元素 - 不幸的是,这种选择以允许删除单个True的方式猜测。 即,它将您的row_mask, col_mask转换为 (2,3) 布尔数组,然后发现它无法索引 (3,3) 数组。

部分问题在于元组和列表被视为等价的,这是我们试图摆脱的。 最终,您将通过确保掩码是双列表来处理布尔数组索引。

不过,就目前而言,我担心唯一的解决方案是执行x[row_mask][:, col_mask]

cc @eric-wieser,他一直致力于弃用索引操作的“将元组视为列表”。

ps 最烦人的是我发现这个区别:

x = np.arange(9).reshape(3, 3)
# x[[False, True, True], True]
# array([[3, 4, 5],
#        [6, 7, 8]])
x[[False, True, True], False]
# IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (0,) 

所有3条评论

对于布尔数组,代码假定您正在尝试同时索引单个维度所有元素 - 不幸的是,这种选择以允许删除单个True的方式猜测。 即,它将您的row_mask, col_mask转换为 (2,3) 布尔数组,然后发现它无法索引 (3,3) 数组。

部分问题在于元组和列表被视为等价的,这是我们试图摆脱的。 最终,您将通过确保掩码是双列表来处理布尔数组索引。

不过,就目前而言,我担心唯一的解决方案是执行x[row_mask][:, col_mask]

cc @eric-wieser,他一直致力于弃用索引操作的“将元组视为列表”。

ps 最烦人的是我发现这个区别:

x = np.arange(9).reshape(3, 3)
# x[[False, True, True], True]
# array([[3, 4, 5],
#        [6, 7, 8]])
x[[False, True, True], False]
# IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (2,) (0,) 

是的, x[row_mask][:, col_mask]就是我最终要做的。 感谢您的解释,我很高兴这是正在研究的东西。

我认为arr[np.ix_(index)]是您在这里想要/期望的,或者在 NEP 21 中的外部索引逻辑的外在的话: https :

也许这会在一段时间内得到解决。 NEP 还表示,至少对于当前的索引,多个布尔索引应该被弃用(我认为是否允许这个特定用例可能仍然存在争议——它是一致的,但可能没有太多用例并且在任何情况)。

此页面是否有帮助?
0 / 5 - 0 等级

相关问题

jakirkham picture jakirkham  ·  55评论

valentinstn picture valentinstn  ·  61评论

ricardoV94 picture ricardoV94  ·  53评论

mrava87 picture mrava87  ·  53评论

numpy-gitbot picture numpy-gitbot  ·  49评论