From be776038b2600b78820805faf1c1c73eb334036a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 23 May 2023 17:15:45 -0400 Subject: [PATCH 1/3] allow custom data for deepmd format Signed-off-by: Jinzhe Zeng --- dpdata/deepmd/comp.py | 35 ++++++++++++++++++++++ dpdata/deepmd/hdf5.py | 35 ++++++++++++++++++++++ dpdata/deepmd/raw.py | 27 +++++++++++++++++ tests/test_custom_data_type.py | 53 ++++++++++++++++++++++++++++++++++ 4 files changed, 150 insertions(+) create mode 100644 tests/test_custom_data_type.py diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index 66bf4ee66..933edb099 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -1,9 +1,11 @@ import glob import os import shutil +import warnings import numpy as np +import dpdata from .raw import load_type @@ -60,6 +62,23 @@ def to_system_data(folder, type_map=None, labels=True): data["forces"] = np.concatenate(all_forces, axis=0) if len(all_virs) > 0: data["virials"] = np.concatenate(all_virs, axis=0) + # allow custom dtypes + if labels: + for dtype in dpdata.system.LabeledSystem.DTYPES: + if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + # skip as these data contains specific rules + continue + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format.") + continue + shape = [-1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]] + all_data = [] + for ii in sets: + tmp = _cond_load_data(os.path.join(ii, dtype.name + ".npy")) + if tmp is not None: + all_data.append(np.reshape(tmp, [tmp.shape[0], *shape])) + if len(all_data) > 0: + data[dtype.name] = np.concatenate(all_data, axis=0) return data @@ -131,3 +150,19 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True): if data.get("nopbc", False): with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc: pass + # allow custom dtypes + for dtype in dpdata.system.LabeledSystem.DTYPES: + if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + # skip as these data contains specific rules + continue + if dtype.name not in data: + continue + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/npy format.") + continue + ddata = np.reshape(data[dtype.name], [nframes, -1]).astype(comp_prec) + for ii in range(nsets): + set_stt = ii * set_size + set_end = (ii + 1) * set_size + set_folder = os.path.join(folder, "set.%03d" % ii) + np.save(os.path.join(set_folder, dtype.name), ddata[set_stt:set_end]) diff --git a/dpdata/deepmd/hdf5.py b/dpdata/deepmd/hdf5.py index 1afb64894..afe8618da 100644 --- a/dpdata/deepmd/hdf5.py +++ b/dpdata/deepmd/hdf5.py @@ -1,10 +1,13 @@ """Utils for deepmd/hdf5 format.""" from typing import Optional, Union +import warnings import h5py import numpy as np from wcmatch.glob import globfilter +import dpdata + __all__ = ["to_system_data", "dump"] @@ -92,6 +95,22 @@ def to_system_data( "required": False, }, } + # allow custom dtypes + for dtype in dpdata.system.LabeledSystem.DTYPES: + if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + # skip as these data contains specific rules + continue + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/hdf5 format.") + continue + + data_types[dtype.name] = { + "fn": dtype.name, + "labeled": True, + "shape": dtype.shape[1:], + "required": False, + } + for dt, prop in data_types.items(): all_data = [] @@ -167,6 +186,22 @@ def dump( "forces": {"fn": "force", "shape": (nframes, -1), "dump": True}, "virials": {"fn": "virial", "shape": (nframes, 9), "dump": True}, } + + # allow custom dtypes + for dtype in dpdata.system.LabeledSystem.DTYPES: + if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + # skip as these data contains specific rules + continue + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/hdf5 format.") + continue + + data_types[dtype.name] = { + "fn": dtype.name, + "shape": (nframes, -1), + "dump": True, + } + for dt, prop in data_types.items(): if dt in data: if prop["dump"]: diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py index 2f2021d44..9719c0178 100644 --- a/dpdata/deepmd/raw.py +++ b/dpdata/deepmd/raw.py @@ -1,7 +1,10 @@ import os +import warnings import numpy as np +import dpdata + def load_type(folder, type_map=None): data = {} @@ -57,6 +60,18 @@ def to_system_data(folder, type_map=None, labels=True): data["virials"] = np.reshape(data["virials"], [nframes, 3, 3]) if os.path.isfile(os.path.join(folder, "nopbc")): data["nopbc"] = True + # allow custom dtypes + if labels: + for dtype in dpdata.system.LabeledSystem.DTYPES: + if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + # skip as these data contains specific rules + continue + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format.") + continue + shape = [-1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]] + if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")): + data[dtype.name] = np.reshape(np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")), [nframes, *shape]) return data else: raise RuntimeError("not dir " + folder) @@ -102,3 +117,15 @@ def dump(folder, data): if data.get("nopbc", False): with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc: pass + # allow custom dtypes + for dtype in dpdata.system.LabeledSystem.DTYPES: + if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + # skip as these data contains specific rules + continue + if dtype.name not in data: + continue + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/raw format.") + continue + ddata = np.reshape(data[dtype.name], [nframes, -1]) + np.savetxt(os.path.join(folder, f"{dtype.name}.raw"), ddata) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py new file mode 100644 index 000000000..946d48978 --- /dev/null +++ b/tests/test_custom_data_type.py @@ -0,0 +1,53 @@ +import unittest + +import numpy as np +import h5py + +import dpdata +from dpdata.system import Axis, DataType + +class TestDeepmdLoadDumpComp(unittest.TestCase): + def setUp(self): + self.backup = dpdata.system.LabeledSystem.DTYPES + dpdata.system.LabeledSystem.DTYPES = dpdata.system.LabeledSystem.DTYPES + ( + DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), + ) + self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar") + self.foo = np.ones((len(self.system), 2, 4)) + self.system.data["foo"] = self.foo + self.system.check_data() + + def tearDown(self) -> None: + dpdata.system.LabeledSystem.DTYPES = self.backup + + def test_to_deepmd_raw(self): + self.system.to_deepmd_raw("data_foo") + foo = np.loadtxt("data_foo/foo.raw") + np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo) + + def test_from_deepmd_raw(self): + self.system.to_deepmd_raw("data_foo") + x = dpdata.LabeledSystem("data_foo", fmt="deepmd/raw") + np.testing.assert_allclose(x.data["foo"], self.foo) + + def test_to_deepmd_npy(self): + self.system.to_deepmd_npy("data_foo") + foo = np.load("data_foo/set.000/foo.npy") + np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo) + + def test_from_deepmd_npy(self): + self.system.to_deepmd_npy("data_foo") + x = dpdata.LabeledSystem("data_foo", fmt="deepmd/npy") + np.testing.assert_allclose(x.data["foo"], self.foo) + + def test_to_deepmd_hdf5(self): + self.system.to_deepmd_hdf5("data_foo.h5") + with h5py.File("data_foo.h5") as f: + foo = f["set.000/foo.npy"][:] + #foo = np.load("data_foo/set.000/foo.npy") + np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo) + + def test_from_deepmd_hdf5(self): + self.system.to_deepmd_hdf5("data_foo.h5") + x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5") + np.testing.assert_allclose(x.data["foo"], self.foo) From ac03342e6b8d14660dd8b8cc0c01169b52cb25b2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 May 2023 21:18:55 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/deepmd/comp.py | 47 ++++++++++++++++++++++++----- dpdata/deepmd/hdf5.py | 43 ++++++++++++++++++++++----- dpdata/deepmd/raw.py | 54 +++++++++++++++++++++++++++++----- tests/test_custom_data_type.py | 5 ++-- 4 files changed, 125 insertions(+), 24 deletions(-) diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index 933edb099..9d63e1dd4 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -6,6 +6,7 @@ import numpy as np import dpdata + from .raw import load_type @@ -65,13 +66,30 @@ def to_system_data(folder, type_map=None, labels=True): # allow custom dtypes if labels: for dtype in dpdata.system.LabeledSystem.DTYPES: - if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + if dtype.name in ( + "atom_numbs", + "atom_names", + "atom_types", + "orig", + "cells", + "coords", + "real_atom_types", + "real_atom_names", + "nopbc", + "energies", + "forces", + "virials", + ): # skip as these data contains specific rules continue - if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): - warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format.") + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn( + f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format." + ) continue - shape = [-1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]] + shape = [ + -1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:] + ] all_data = [] for ii in sets: tmp = _cond_load_data(os.path.join(ii, dtype.name + ".npy")) @@ -152,13 +170,28 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True): pass # allow custom dtypes for dtype in dpdata.system.LabeledSystem.DTYPES: - if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + if dtype.name in ( + "atom_numbs", + "atom_names", + "atom_types", + "orig", + "cells", + "coords", + "real_atom_types", + "real_atom_names", + "nopbc", + "energies", + "forces", + "virials", + ): # skip as these data contains specific rules continue if dtype.name not in data: continue - if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): - warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/npy format.") + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn( + f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/npy format." + ) continue ddata = np.reshape(data[dtype.name], [nframes, -1]).astype(comp_prec) for ii in range(nsets): diff --git a/dpdata/deepmd/hdf5.py b/dpdata/deepmd/hdf5.py index afe8618da..b4ae1a3c6 100644 --- a/dpdata/deepmd/hdf5.py +++ b/dpdata/deepmd/hdf5.py @@ -1,6 +1,6 @@ """Utils for deepmd/hdf5 format.""" -from typing import Optional, Union import warnings +from typing import Optional, Union import h5py import numpy as np @@ -97,11 +97,26 @@ def to_system_data( } # allow custom dtypes for dtype in dpdata.system.LabeledSystem.DTYPES: - if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + if dtype.name in ( + "atom_numbs", + "atom_names", + "atom_types", + "orig", + "cells", + "coords", + "real_atom_types", + "real_atom_names", + "nopbc", + "energies", + "forces", + "virials", + ): # skip as these data contains specific rules continue - if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): - warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/hdf5 format.") + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn( + f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/hdf5 format." + ) continue data_types[dtype.name] = { @@ -111,7 +126,6 @@ def to_system_data( "required": False, } - for dt, prop in data_types.items(): all_data = [] @@ -189,11 +203,26 @@ def dump( # allow custom dtypes for dtype in dpdata.system.LabeledSystem.DTYPES: - if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + if dtype.name in ( + "atom_numbs", + "atom_names", + "atom_types", + "orig", + "cells", + "coords", + "real_atom_types", + "real_atom_names", + "nopbc", + "energies", + "forces", + "virials", + ): # skip as these data contains specific rules continue if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): - warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/hdf5 format.") + warnings.warn( + f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/hdf5 format." + ) continue data_types[dtype.name] = { diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py index 9719c0178..fdb2fc649 100644 --- a/dpdata/deepmd/raw.py +++ b/dpdata/deepmd/raw.py @@ -63,15 +63,38 @@ def to_system_data(folder, type_map=None, labels=True): # allow custom dtypes if labels: for dtype in dpdata.system.LabeledSystem.DTYPES: - if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + if dtype.name in ( + "atom_numbs", + "atom_names", + "atom_types", + "orig", + "cells", + "coords", + "real_atom_types", + "real_atom_names", + "nopbc", + "energies", + "forces", + "virials", + ): # skip as these data contains specific rules continue - if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): - warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format.") + if not ( + len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES + ): + warnings.warn( + f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format." + ) continue - shape = [-1 if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]] + shape = [ + -1 if xx == dpdata.system.Axis.NATOMS else xx + for xx in dtype.shape[1:] + ] if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")): - data[dtype.name] = np.reshape(np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")), [nframes, *shape]) + data[dtype.name] = np.reshape( + np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")), + [nframes, *shape], + ) return data else: raise RuntimeError("not dir " + folder) @@ -119,13 +142,28 @@ def dump(folder, data): pass # allow custom dtypes for dtype in dpdata.system.LabeledSystem.DTYPES: - if dtype.name in ("atom_numbs", "atom_names", "atom_types", "orig", "cells", "coords", "real_atom_types", "real_atom_names", "nopbc", "energies", "forces", "virials"): + if dtype.name in ( + "atom_numbs", + "atom_names", + "atom_types", + "orig", + "cells", + "coords", + "real_atom_types", + "real_atom_names", + "nopbc", + "energies", + "forces", + "virials", + ): # skip as these data contains specific rules continue if dtype.name not in data: continue - if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): - warnings.warn(f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/raw format.") + if not (len(dtype.shape) and dtype.shape[0] == dpdata.system.Axis.NFRAMES): + warnings.warn( + f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/raw format." + ) continue ddata = np.reshape(data[dtype.name], [nframes, -1]) np.savetxt(os.path.join(folder, f"{dtype.name}.raw"), ddata) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 946d48978..e813c5ed1 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -1,11 +1,12 @@ import unittest -import numpy as np import h5py +import numpy as np import dpdata from dpdata.system import Axis, DataType + class TestDeepmdLoadDumpComp(unittest.TestCase): def setUp(self): self.backup = dpdata.system.LabeledSystem.DTYPES @@ -44,7 +45,7 @@ def test_to_deepmd_hdf5(self): self.system.to_deepmd_hdf5("data_foo.h5") with h5py.File("data_foo.h5") as f: foo = f["set.000/foo.npy"][:] - #foo = np.load("data_foo/set.000/foo.npy") + # foo = np.load("data_foo/set.000/foo.npy") np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo) def test_from_deepmd_hdf5(self): From 0b0dd8e5be60b3f66c99dab3c226aa309450af74 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 23 May 2023 17:19:45 -0400 Subject: [PATCH 3/3] Update test_custom_data_type.py --- tests/test_custom_data_type.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index e813c5ed1..58b9f5e49 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -45,7 +45,6 @@ def test_to_deepmd_hdf5(self): self.system.to_deepmd_hdf5("data_foo.h5") with h5py.File("data_foo.h5") as f: foo = f["set.000/foo.npy"][:] - # foo = np.load("data_foo/set.000/foo.npy") np.testing.assert_allclose(foo.reshape(self.foo.shape), self.foo) def test_from_deepmd_hdf5(self):