Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dpdata/plugins/deepmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def from_system_mix(self, file_name, type_map=None, **kwargs):
file_name, type_map=type_map, labels=False
)

def to_system(self, data, file_name, prec=np.float64, **kwargs):
def to_system(
self, data, file_name, set_size: int = 2000, prec=np.float64, **kwargs
):
"""Dump the system in deepmd mixed type format (numpy binary) to `folder`.

The frames were already split to different systems, so these frames can be dumped to one single subfolders
Expand All @@ -107,12 +109,14 @@ def to_system(self, data, file_name, prec=np.float64, **kwargs):
System data
file_name : str
The output folder
set_size : int, default=2000
set size
prec : {numpy.float32, numpy.float64}
The floating point precision of the compressed data
**kwargs : dict
other parameters
"""
dpdata.deepmd.mixed.dump(file_name, data, comp_prec=prec)
dpdata.deepmd.mixed.dump(file_name, data, set_size=set_size, comp_prec=prec)

def from_labeled_system_mix(self, file_name, type_map=None, **kwargs):
return dpdata.deepmd.mixed.to_system_data(
Expand Down
101 changes: 101 additions & 0 deletions tests/test_deepmd_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,107 @@ def test_str(self):
)


class TestMixedMultiSystemsDumpLoadSetSize(
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
):
def setUp(self):
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 6

# C1H4
system_1 = dpdata.LabeledSystem(
"gaussian/methane.gaussianlog", fmt="gaussian/log"
)

# C1H3
system_2 = dpdata.LabeledSystem(
"gaussian/methane_sub.gaussianlog", fmt="gaussian/log"
)

tmp_data = system_1.data.copy()
tmp_data["atom_numbs"] = [1, 1, 1, 2]
tmp_data["atom_names"] = ["C", "H", "A", "B"]
tmp_data["atom_types"] = np.array([0, 1, 2, 3, 3])
# C1H1A1B2
system_1_modified_type_1 = dpdata.LabeledSystem(data=tmp_data)

tmp_data = system_1.data.copy()
tmp_data["atom_numbs"] = [1, 1, 2, 1]
tmp_data["atom_names"] = ["C", "H", "A", "B"]
tmp_data["atom_types"] = np.array([0, 1, 2, 2, 3])
# C1H1A2B1
system_1_modified_type_2 = dpdata.LabeledSystem(data=tmp_data)

tmp_data = system_1.data.copy()
tmp_data["atom_numbs"] = [1, 1, 1, 2]
tmp_data["atom_names"] = ["C", "H", "A", "D"]
tmp_data["atom_types"] = np.array([0, 1, 2, 3, 3])
# C1H1A1C2
system_1_modified_type_3 = dpdata.LabeledSystem(data=tmp_data)

self.ms = dpdata.MultiSystems(
system_1,
system_2,
system_1_modified_type_1,
system_1_modified_type_2,
system_1_modified_type_3,
)
self.ms.to_deepmd_npy_mixed("tmp.deepmd.mixed", set_size=1)
self.place_holder_ms = dpdata.MultiSystems()
self.place_holder_ms.from_deepmd_npy("tmp.deepmd.mixed", fmt="deepmd/npy")
self.systems = dpdata.MultiSystems()
self.systems.from_deepmd_npy_mixed("tmp.deepmd.mixed", fmt="deepmd/npy/mixed")
self.system_1 = self.ms["C1H4A0B0D0"]
self.system_2 = self.systems["C1H4A0B0D0"]
mixed_sets = glob("tmp.deepmd.mixed/*/set.*")
self.assertEqual(len(mixed_sets), 5)
for i in mixed_sets:
self.assertEqual(
os.path.exists(os.path.join(i, "real_atom_types.npy")), True
)

self.system_names = [
"C1H4A0B0D0",
"C1H3A0B0D0",
"C1H1A1B2D0",
"C1H1A2B1D0",
"C1H1A1B0D2",
]
self.system_sizes = {
"C1H4A0B0D0": 1,
"C1H3A0B0D0": 1,
"C1H1A1B2D0": 1,
"C1H1A2B1D0": 1,
"C1H1A1B0D2": 1,
}
self.atom_names = ["C", "H", "A", "B", "D"]

def tearDown(self):
if os.path.exists("tmp.deepmd.mixed"):
shutil.rmtree("tmp.deepmd.mixed")

def test_len(self):
self.assertEqual(len(self.ms), 5)
self.assertEqual(len(self.place_holder_ms), 2)
self.assertEqual(len(self.systems), 5)

def test_get_nframes(self):
self.assertEqual(self.ms.get_nframes(), 5)
self.assertEqual(self.place_holder_ms.get_nframes(), 5)
self.assertEqual(self.systems.get_nframes(), 5)

def test_str(self):
self.assertEqual(str(self.ms), "MultiSystems (5 systems containing 5 frames)")
self.assertEqual(
str(self.place_holder_ms), "MultiSystems (2 systems containing 5 frames)"
)
self.assertEqual(
str(self.systems), "MultiSystems (5 systems containing 5 frames)"
)


class TestMixedMultiSystemsTypeChange(
unittest.TestCase, CompLabeledSys, MultiSystems, IsNoPBC
):
Expand Down