提问者:小点点

如果我只需要绕对角线的带,如何加速numpy中的矩阵乘法?


我需要计算方阵a(a*a^t)的二次幂,但我只对结果对角线附近的值感兴趣。 换句话说,我需要计算相邻行的点积,其中邻域是由某个固定大小的窗口定义的,理想情况下,我希望避免计算剩余的点积。 如何在numpy中做到这一点,而不运行带有一些掩蔽的全矩阵乘法? 结果数组应如下所示:

a1*a1  a1*a2  0      0      0      0
a2*a1  a2*a2  a2*a3  0      0      0
0      a3*a2  a3*a3  a3*a4  0      0
0      0      a4*a3  a4*a4  a4*a5  0...
0      0      0      ...
...

示例矩阵包含相邻行的点积。 每行只与左右相邻行相乘。 为节省时间,理想情况下不应通过解计算零点。 这条线索似乎指向类似的方向。


共1个答案

匿名用户

使用scipy查看稀疏矩阵(其中numpy也来自)。

对于您的具体问题:

>

  • 对角线元素是矩阵及其转置的元素乘积的列和v=np.sum(np.multiply(A,A.t),axis=0)

    off对角线元素是相同的,只是删除了最后一行/列,并用第一个索引处的零列/行替代:

    pos_offset = np.concatenate((np.zeros((n, 1)), A[:, :-1]), axis=1)
    v_pos = np.sum(np.multiply(A, pos_offset.T), axis=0)
    # similar for the negative offset diagonal
    
    A_res = np.diag(v) + np.diag(v_pos, k=1) + np.diag(v_neg, k=-1)