在numpy或pytorch中自动获取对角矩阵条纹


问题内容

我需要获得矩阵的对角线条纹(不确定此处的术语,对角线矩阵条纹似乎最能描述它)。

假设我有一个大小为KxN的矩阵,其中K和N是任意大小,K> N。说,我有一个矩阵:

[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]

我需要从中提取对角线条纹,在这种情况下,是通过截断原始矩阵而创建的矩阵MxV大小:

[[ 0  x  x]
 [ 3  4  x]
 [ x  7  8]
 [ x  x  11]]

因此,结果矩阵为:

[[ 0  4  8]
 [ 3  7  11]]

这是一个使用矩阵遮罩来去除遮罩位置的小示例代码:

import numpy as np
X=np.arange(12).reshape(4,3)
mask=np.asarray([
  [ True,  False,  False],
  [ True,  True,  False], 
  [ False, True,  True], 
  [ False, False,  True]
])

>>> mask
array([[ True, False, False],
       [ True,  True, False],
       [False,  True,  True],
       [False, False,  True]], dtype=bool)

>>> X
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]])

>>> X.T[mask.T].reshape(3,2).T
array([[ 0,  4,  8],
       [ 3,  7, 11]])

但我看不到如何自动将这样的蒙版生成为任何K和N尺寸(ei 39x9或360x96)

任何帮助表示赞赏。也许有一些功能可以自动在numpy,scipy或pytorch中执行此操作?

编辑:

我还有另一个问题,是否有可能得到:

[[ 0  x  x]
 [ 3  4  x]
 [ x  7  8]
 [ x  x  11]]

要获得像这样的反向条纹:

[[ x   x   2]
 [ x   4   5]
 [ 6   7   x]
 [ 9   x   x]]

问题答案:

stride_tricks 做到这一点:

>>> import numpy as np
>>> 
>>> def stripe(a):
...    a = np.asanyarray(a)
...    *sh, i, j = a.shape
...    assert i >= j
...    *st, k, m = a.strides
...    return np.lib.stride_tricks.as_strided(a, (*sh, i-j+1, j), (*st, k, k+m))
... 
>>> a = np.arange(24).reshape(6, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])
>>> stripe(a)
array([[ 0,  5, 10, 15],
       [ 4,  9, 14, 19],
       [ 8, 13, 18, 23]])

如果a是一个数组,则会创建一个可写的视图,这意味着,如果您觉得这样,可以执行以下操作:

>>> stripe(a)[...] *= 10
>>> a
array([[  0,   1,   2,   3],
       [ 40,  50,   6,   7],
       [ 80,  90, 100,  11],
       [ 12, 130, 140, 150],
       [ 16,  17, 180, 190],
       [ 20,  21,  22, 230]])

更新:可以相同的方式获得从左下到右上的条纹。仅很小的复杂性:它不基于与原始数组相同的地址。

>>> def reverse_stripe(a):
...     a = np.asanyarray(a)
...     *sh, i, j = a.shape
...     assert i >= j
...     *st, k, m = a.strides
...     return np.lib.stride_tricks.as_strided(a[..., j-1:, :], (*sh, i-j+1, j), (*st, k, m-k))
... 
>>> a = np.arange(24).reshape(6, 4)
>>> reverse_stripe(a)
array([[12,  9,  6,  3],
       [16, 13, 10,  7],
       [20, 17, 14, 11]])