Skip to content

Commit 9fe2892

Browse files
committed
Fix storage of scalars in state files
1 parent c975139 commit 9fe2892

File tree

1 file changed

+57
-27
lines changed

1 file changed

+57
-27
lines changed

varipeps/optimization/optimizer.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -257,33 +257,51 @@ def autosave_function_restartable(
257257
)
258258
grp_old_grad.attrs["len"] = len(old_gradient)
259259
for i, g in enumerate(old_gradient):
260-
grp_old_grad.create_dataset(
261-
f"old_grad_{i:d}", data=g, compression="gzip", compression_opts=6
262-
)
260+
if g.ndim == 0:
261+
grp_old_grad.create_dataset(f"old_grad_{i:d}", data=g)
262+
else:
263+
grp_old_grad.create_dataset(
264+
f"old_grad_{i:d}",
265+
data=g,
266+
compression="gzip",
267+
compression_opts=6,
268+
)
263269

264270
if old_descent_dir is not None:
265271
grp_old_des_dir = grp_restart_data.create_group(
266272
"old_descent_dir", track_order=True
267273
)
268274
grp_old_des_dir.attrs["len"] = len(old_descent_dir)
269275
for i, d in enumerate(old_descent_dir):
270-
grp_old_des_dir.create_dataset(
271-
f"old_descent_dir_{i:d}",
272-
data=d,
273-
compression="gzip",
274-
compression_opts=6,
275-
)
276+
if d.ndim == 0:
277+
grp_old_des_dir.create_dataset(
278+
f"old_descent_dir_{i:d}",
279+
data=d,
280+
)
281+
else:
282+
grp_old_des_dir.create_dataset(
283+
f"old_descent_dir_{i:d}",
284+
data=d,
285+
compression="gzip",
286+
compression_opts=6,
287+
)
276288

277289
if best_unitcell is not None:
278290
grp_best_t = grp_restart_data.create_group("best_tensors", track_order=True)
279291
grp_best_t.attrs["len"] = len(best_tensors)
280292
for i, t in enumerate(best_tensors):
281-
grp_best_t.create_dataset(
282-
f"best_tensor_{i:d}",
283-
data=t,
284-
compression="gzip",
285-
compression_opts=6,
286-
)
293+
if t.ndim == 0:
294+
grp_best_t.create_dataset(
295+
f"best_tensor_{i:d}",
296+
data=t,
297+
)
298+
else:
299+
grp_best_t.create_dataset(
300+
f"best_tensor_{i:d}",
301+
data=t,
302+
compression="gzip",
303+
compression_opts=6,
304+
)
287305

288306
grp_best_u = grp_restart_data.create_group("best_unitcell")
289307
best_unitcell.save_to_group(grp_best_u, False)
@@ -317,18 +335,30 @@ def autosave_function_restartable(
317335
if len(x) != len(g) != grp_l_bfgs.attrs["len_elems"]:
318336
raise ValueError("L-BFGS list lengths mismatch.")
319337
for j in range(grp_l_bfgs.attrs["len_elems"]):
320-
grp_l_bfgs.create_dataset(
321-
f"x_{i:d}_{j:d}",
322-
data=x[j],
323-
compression="gzip",
324-
compression_opts=6,
325-
)
326-
grp_l_bfgs.create_dataset(
327-
f"grad_{i:d}_{j:d}",
328-
data=g[j],
329-
compression="gzip",
330-
compression_opts=6,
331-
)
338+
if x[j].ndim == 0:
339+
grp_l_bfgs.create_dataset(
340+
f"x_{i:d}_{j:d}",
341+
data=x[j],
342+
)
343+
else:
344+
grp_l_bfgs.create_dataset(
345+
f"x_{i:d}_{j:d}",
346+
data=x[j],
347+
compression="gzip",
348+
compression_opts=6,
349+
)
350+
if g[j].ndim == 0:
351+
grp_l_bfgs.create_dataset(
352+
f"grad_{i:d}_{j:d}",
353+
data=g[j],
354+
)
355+
else:
356+
grp_l_bfgs.create_dataset(
357+
f"grad_{i:d}_{j:d}",
358+
data=g[j],
359+
compression="gzip",
360+
compression_opts=6,
361+
)
332362

333363

334364
def _autosave_wrapper(

0 commit comments

Comments
 (0)