Skip to content

Commit

Permalink
make cistring.gen_occslst a public function
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung authored and sunqm committed Apr 29, 2023
1 parent 0b957fb commit 4706c95
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/fci/11-large_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# Output all determinants coefficients
print(' det-alpha, det-beta, CI coefficients')
occslst = fci.cistring._gen_occslst(range(ncas), nelec//2)
occslst = fci.cistring.gen_occslst(range(ncas), nelec//2)
for i,occsa in enumerate(occslst):
for j,occsb in enumerate(occslst):
print(' %s %s %.12f' % (occsa, occsb, mc.ci[i,j]))
Expand Down
4 changes: 2 additions & 2 deletions pyscf/fci/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ def irrep_id2lz(irrep_id):
raise NotImplementedError
orb_lz = wfn_lz = d2h_wfnsym_id = None

occslsta = occslstb = cistring._gen_occslst(range(norb), neleca)
occslsta = occslstb = cistring.gen_occslst(range(norb), neleca)
if neleca != nelecb:
occslstb = cistring._gen_occslst(range(norb), nelecb)
occslstb = cistring.gen_occslst(range(norb), nelecb)
na = len(occslsta)
nb = len(occslsta)

Expand Down
22 changes: 19 additions & 3 deletions pyscf/fci/cistring.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def make_strings(orb_list, nelec):
'''
orb_list = list(orb_list)
if len(orb_list) > 63:
return _gen_occslst(orb_list, nelec)
return gen_occslst(orb_list, nelec)

assert (nelec >= 0)
if nelec == 0:
Expand All @@ -68,8 +68,24 @@ def gen_str_iter(orb_list, nelec):
return numpy.asarray(strings, dtype=numpy.int64)
gen_strings4orblist = make_strings

def _gen_occslst(orb_list, nelec):
def gen_occslst(orb_list, nelec):
'''Generate occupied orbital list for each string.
Returns:
List of lists of int32. Each inner list has length equal to the number of
electrons, and contains the occupied orbitals in the corresponding string.
Example:
>>> [bin(x) for x in make_strings((0, 1, 2, 3), 2)]
['0b11', '0b101', '0b110', '0b1001', '0b1010', '0b1100']
>>> gen_occslst((0, 1, 2, 3), 2)
OIndexList([[0, 1],
[0, 2],
[1, 2],
[0, 3],
[1, 3],
[2, 3]], dtype=int32)
'''
orb_list = list(orb_list)
assert (nelec >= 0)
Expand Down Expand Up @@ -143,7 +159,7 @@ def gen_linkstr_index_o1(orb_list, nelec, strs=None, tril=False):
return numpy.zeros((0,0,4), dtype=numpy.int32)

if strs is None:
strs = _gen_occslst(orb_list, nelec)
strs = gen_occslst(orb_list, nelec)
occslst = strs

orb_list = numpy.asarray(orb_list)
Expand Down
4 changes: 2 additions & 2 deletions pyscf/fci/direct_spin1.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def make_hdiag(h1e, eri, norb, nelec):
neleca, nelecb = _unpack_nelec(nelec)
h1e = numpy.asarray(h1e, order='C')
eri = ao2mo.restore(1, eri, norb)
occslsta = occslstb = cistring._gen_occslst(range(norb), neleca)
occslsta = occslstb = cistring.gen_occslst(range(norb), neleca)
if neleca != nelecb:
occslstb = cistring._gen_occslst(range(norb), nelecb)
occslstb = cistring.gen_occslst(range(norb), nelecb)
na = len(occslsta)
nb = len(occslstb)

Expand Down
4 changes: 2 additions & 2 deletions pyscf/fci/direct_uhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def make_hdiag(h1e, eri, norb, nelec):
g2e_ab = ao2mo.restore(1, eri[1], norb)
g2e_bb = ao2mo.restore(1, eri[2], norb)

occslsta = occslstb = cistring._gen_occslst(range(norb), neleca)
occslsta = occslstb = cistring.gen_occslst(range(norb), neleca)
if neleca != nelecb:
occslstb = cistring._gen_occslst(range(norb), nelecb)
occslstb = cistring.gen_occslst(range(norb), nelecb)
na = len(occslsta)
nb = len(occslstb)

Expand Down
2 changes: 1 addition & 1 deletion pyscf/fci/fci_dhf_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def absorb_h1e(h1e, eri, norb, nelec, fac=1):


def make_hdiag(h1e, eri, norb, nelec, opt=None):
occslist = cistring._gen_occslst(range(norb), nelec)
occslist = cistring.gen_occslst(range(norb), nelec)
diagjk = numpy.einsum('iijj->ij', eri.copy(), optimize=True)
diagjk -= numpy.einsum('ijji->ij', eri, optimize=True)
hdiag = []
Expand Down
4 changes: 2 additions & 2 deletions pyscf/fci/fci_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def make_hdiag(h1e, eri, norb, nelec, opt=None):
else:
neleca, nelecb = nelec

occslista = cistring._gen_occslst(range(norb), neleca)
occslistb = cistring._gen_occslst(range(norb), nelecb)
occslista = cistring.gen_occslst(range(norb), neleca)
occslistb = cistring.gen_occslst(range(norb), nelecb)
eri = ao2mo.restore(1, eri, norb)
diagj = numpy.einsum('iijj->ij', eri)
diagk = numpy.einsum('ijji->ij', eri)
Expand Down
2 changes: 1 addition & 1 deletion pyscf/fci/test/test_cistring.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_strings4orblist(self):
self.assertEqual(bin(x), ref[i])

strs = cistring.gen_strings4orblist(range(8), 4)
occlst = cistring._gen_occslst(range(8), 4)
occlst = cistring.gen_occslst(range(8), 4)
self.assertAlmostEqual(abs(occlst - cistring._strs2occslst(strs, 8)).sum(), 0, 12)
self.assertAlmostEqual(abs(strs - cistring._occslst2strs(occlst)).sum(), 0, 12)

Expand Down

0 comments on commit 4706c95

Please sign in to comment.