您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

Python,numpy,einsum乘以一堆矩阵

Python,numpy,einsum乘以一堆矩阵

我认为使用numpy不可能有效地做到这一点(cumprod尽管解决方案很优雅)。我会用这种情况f2py。这是我所知道的调用更快语言的最简单方法,只需要一个额外的文件即可。

fortran.f90:

subroutine multimul(a, b)
  implicit none
  real(8), intent(in)  :: a(:,:,:,:)
  real(8), intent(out) :: b(size(a,1),size(a,2),size(a,3))
  real(8) :: work(size(a,1),size(a,2))
  integer i, j, k, l, m
  !$omp parallel do private(work,i,j)
  do i = 1, size(b,3)
    b(:,:,i) = a(:,:,i,size(a,4)) 
    do j = size(a,4)-1, 1, -1
      work = matmul(b(:,:,i),a(:,:,i,j))
      b(:,:,i) = work
    end do
  end do
end subroutine

编译f2py -c -m fortran fortran.f90(或F90FLAGS="-fopenmp" f2py -c -m fortran fortran.f90 -lgomp启用OpenMP加速)。然后您将在脚本中使用它

import numpy as np, fmuls
Arr = np.random.standard_normal([500,201,2,2])
def loopMult(Arr):
  ArrMult = Arr[0]
  for i in range(1,len(Arr)):
    ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
  return ArrMult
def myeinsum(A1, A2):
  return np.einsum('fij,fjk->fik', A1, A2)
A1 = loopMult(Arr)
A2 = reduce(myeinsum, Arr)
A3 = fmuls.multimul(Arr.T).T
print np.allclose(A1,A2)
print np.allclose(A1,A3)
%timeit loopMult(Arr)
%timeit reduce(myeinsum, Arr)
%timeit fmuls.multimul(Arr.T).T

哪个输出

True
True
10 loops, best of 3: 48.4 ms per loop
10 loops, best of 3: 48.8 ms per loop
100 loops, best of 3: 5.82 ms per loop

这就是加速因素8。所有转置的原因是f2py隐式转置所有数组,我们需要手动转置它们以告诉我们我们的fortran代码期望事物被转置。这避免了复制操作。代价是我们的每个2x2矩阵都是换位的,因此为了避免执行错误的操作,我们必须反向循环。

大于8的加速比应该是可能的-我没有花任何时间来优化它。

python 2022/1/1 18:45:10 有478人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶