Skip to content

Commit

Permalink
efficiency of transpose for derivative of MO cderi
Browse files Browse the repository at this point in the history
  • Loading branch information
ajz34 committed May 22, 2021
1 parent 8ae4412 commit 7943e50
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
9 changes: 9 additions & 0 deletions pyscf/dh/dhutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,12 @@ def restricted_biorthogonalize(t_ijab, cc, c_os, c_ss):
res *= coef_1
res += coef_0 * t_ijab
return res


def hermi_sum_last2dim_inplace(tsr, hermi=1):
# shameless call lib.hermi_sum, just for a tensor wrapper
tsr_shape = tsr.shape
tsr.shape = (-1, tsr.shape[-1], tsr.shape[-2])
res = lib.hermi_sum(tsr, axes=(0, 2, 1), hermi=hermi, inplace=True)
res.shape = tsr_shape
return res
8 changes: 5 additions & 3 deletions pyscf/dh/polar/rdfdh.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dh import RDFDH
from dh.dhutil import gen_batch, get_rho_from_dm_gga, restricted_biorthogonalize
from dh.dhutil import gen_batch, get_rho_from_dm_gga, restricted_biorthogonalize, hermi_sum_last2dim_inplace
from pyscf import gto, lib, dft
import numpy as np

Expand Down Expand Up @@ -265,12 +265,14 @@ def get_SCR3(self):
pdA_G_blk = np.asarray(pdA_G_ia_ri[:, saux])
# pdA_Y_ij part
pdA_Y_blk = einsum("Ami, Pmj -> APij", U_1[:, :, so], Y_blk[:, :, so])
pdA_Y_blk += pdA_Y_blk.swapaxes(-1, -2)
# pdA_Y_blk += pdA_Y_blk.swapaxes(-1, -2)
hermi_sum_last2dim_inplace(pdA_Y_blk)
SCR3 -= 4 * einsum("APja, Pij -> Aai", pdA_G_blk, Y_blk[:, so, so])
SCR3 -= 4 * einsum("Pja, APij -> Aai", G_blk, pdA_Y_blk)
# pdA_Y_ab part
pdA_Y_blk = einsum("Ama, Pmb -> APab", U_1[:, :, sv], Y_blk[:, :, sv])
pdA_Y_blk += pdA_Y_blk.swapaxes(-1, -2)
# pdA_Y_blk += pdA_Y_blk.swapaxes(-1, -2)
hermi_sum_last2dim_inplace(pdA_Y_blk)
SCR3 += 4 * einsum("APib, Pab -> Aai", pdA_G_blk, Y_blk[:, sv, sv])
SCR3 += 4 * einsum("Pib, APab -> Aai", G_blk, pdA_Y_blk)
if self.xc_n:
Expand Down

0 comments on commit 7943e50

Please sign in to comment.