|
1 | 1 | import jax.numpy as jnp
|
2 | 2 |
|
3 |
| -from varipeps.ctmrg import calc_ctmrg_env |
| 3 | +from varipeps import varipeps_config |
| 4 | +from varipeps.ctmrg import calc_ctmrg_env, CTMRGNotConvergedError |
4 | 5 | from varipeps.peps import PEPS_Unit_Cell
|
5 | 6 |
|
6 | 7 | from . import overlap_single_site
|
|
12 | 13 | }
|
13 | 14 |
|
14 | 15 |
|
15 |
| -def calculate_overlap(unitcell_A, unitcell_B, chi, max_chi): |
| 16 | +def calculate_overlap( |
| 17 | + unitcell_A, unitcell_B, chi, max_chi, *, test_automatic_overlap_conv=True |
| 18 | +): |
16 | 19 | structure_A = tuple(tuple(i) for i in unitcell_A.data.structure)
|
17 | 20 | structure_B = tuple(tuple(i) for i in unitcell_B.data.structure)
|
18 | 21 |
|
@@ -43,10 +46,40 @@ def calculate_overlap(unitcell_A, unitcell_B, chi, max_chi):
|
43 | 46 |
|
44 | 47 | overlap_unitcell = PEPS_Unit_Cell.from_tensor_list(overlap_tensors, structure_A)
|
45 | 48 |
|
46 |
| - overlap_unitcell, _ = calc_ctmrg_env( |
47 |
| - [i.tensor for i in overlap_tensors], overlap_unitcell |
48 |
| - ) |
| 49 | + if test_automatic_overlap_conv: |
| 50 | + tmp_max_steps = varipeps_config.ctmrg_max_steps |
| 51 | + tmp_fail = varipeps_config.ctmrg_fail_if_not_converged |
49 | 52 |
|
50 |
| - overlap_AB = overlap_func(overlap_unitcell) |
| 53 | + varipeps_config.ctmrg_fail_if_not_converged = False |
| 54 | + varipeps_config.ctmrg_max_steps = 10 |
| 55 | + |
| 56 | + overlap_unitcell, _ = calc_ctmrg_env( |
| 57 | + [i.tensor for i in overlap_tensors], overlap_unitcell |
| 58 | + ) |
| 59 | + overlap_AB = overlap_func(overlap_unitcell) |
| 60 | + |
| 61 | + for count in range(tmp_max_steps): |
| 62 | + old_overlap_AB = overlap_AB |
| 63 | + |
| 64 | + overlap_unitcell, _ = calc_ctmrg_env( |
| 65 | + [i.tensor for i in overlap_tensors], overlap_unitcell |
| 66 | + ) |
| 67 | + overlap_AB = overlap_func(overlap_unitcell) |
| 68 | + |
| 69 | + if ( |
| 70 | + jnp.abs(overlap_AB - old_overlap_AB) |
| 71 | + <= varipeps_config.ctmrg_convergence_eps |
| 72 | + ): |
| 73 | + varipeps_config.ctmrg_fail_if_not_converged = tmp_fail |
| 74 | + varipeps_config.ctmrg_max_steps = tmp_max_steps |
| 75 | + break |
| 76 | + if count == (tmp_max_steps - 1): |
| 77 | + raise CTMRGNotConvergedError |
| 78 | + else: |
| 79 | + overlap_unitcell, _ = calc_ctmrg_env( |
| 80 | + [i.tensor for i in overlap_tensors], overlap_unitcell |
| 81 | + ) |
| 82 | + |
| 83 | + overlap_AB = overlap_func(overlap_unitcell) |
51 | 84 |
|
52 | 85 | return overlap_AB
|
0 commit comments