Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 164 additions & 22 deletions dpdata/deepmd/mixed.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,128 @@
from __future__ import annotations

import copy
import math

import numpy as np

import dpdata
from dpdata.data_type import Axis

from .comp import dump as comp_dump
from .comp import to_system_data as comp_to_system_data


def _pad_to(sys_data, target_natoms, dtypes):
"""Pad system data dict so that NATOMS dimension becomes target_natoms.

Virtual atoms get real_atom_types = -1, and all other per-atom data is
padded with zeros.

Parameters
----------
sys_data : dict
System data dict, already in mixed-type format.
target_natoms : int
Target number of atoms after padding.
dtypes : tuple[DataType, ...]
Registered data types to iterate for generic per-atom padding.
"""
natoms = sys_data["atom_types"].shape[0]
npad = target_natoms - natoms
if npad <= 0:
return
nframes = sys_data["coords"].shape[0]

# Pad atom_types (all MIXED_TOKEN = 0)
sys_data["atom_types"] = np.concatenate(
[sys_data["atom_types"], np.zeros(npad, dtype=int)]
)
sys_data["atom_numbs"] = [target_natoms]

# Pad real_atom_types with -1 (virtual atom sentinel)
sys_data["real_atom_types"] = np.concatenate(
[
sys_data["real_atom_types"],
-np.ones((nframes, npad), dtype=sys_data["real_atom_types"].dtype),
],
axis=1,
)

# Pad coords and all other per-atom data generically
reserved = {
"atom_numbs",
"atom_names",
"atom_types",
"orig",
"cells",
"real_atom_names",
"real_atom_types",
"nopbc",
}
for dtype in dtypes:
if dtype.name in reserved:
continue
if dtype.name not in sys_data:
continue
if not (
len(dtype.shape) >= 2
and dtype.shape[0] == Axis.NFRAMES
and Axis.NATOMS in dtype.shape
):
continue
axis_natoms = list(dtype.shape).index(Axis.NATOMS)
arr = sys_data[dtype.name]
pad_width = [(0, 0)] * len(arr.shape)
pad_width[axis_natoms] = (0, npad)
sys_data[dtype.name] = np.pad(
arr, pad_width, mode="constant", constant_values=0
)


def _strip_virtual_atoms(atom_types_row, coords, extra_data, dtypes):
"""Strip virtual atoms (type -1) from a group of frames.

Parameters
----------
atom_types_row : np.ndarray
1-D array of atom type indices for the group (same for all frames).
coords : np.ndarray
Coordinates array, shape (nframes, natoms_padded, 3).
extra_data : dict
Dict of {name: array} for this group, arrays already frame-sliced.
dtypes : tuple[DataType, ...]
Registered data types.

Returns
-------
atom_types : np.ndarray
Atom types with virtual atoms removed.
coords : np.ndarray
Coords with virtual atoms removed.
extra_data : dict
Extra data with virtual atoms removed.
"""
real_mask = atom_types_row >= 0
if real_mask.all():
return atom_types_row, coords, extra_data

atom_types = atom_types_row[real_mask]
coords = coords[:, real_mask, :]

stripped = {}
for name, arr in extra_data.items():
for dtype in dtypes:
if dtype.name == name and Axis.NATOMS in dtype.shape:
axis_natoms = list(dtype.shape).index(Axis.NATOMS)
idx = [slice(None)] * len(arr.shape)
idx[axis_natoms] = real_mask
arr = arr[tuple(idx)]
break
stripped[name] = arr

return atom_types, coords, stripped


def to_system_data(folder, type_map=None, labels=True):
data = comp_to_system_data(folder, type_map, labels)
# data is empty
Expand All @@ -26,7 +139,11 @@ def to_system_data(folder, type_map=None, labels=True):
index_map = None
all_real_atom_types_concat = data.pop("real_atom_types").astype(int)
if index_map is not None:
all_real_atom_types_concat = index_map[all_real_atom_types_concat]
# Preserve -1 (virtual atom sentinel) during remapping
valid = all_real_atom_types_concat >= 0
remapped = np.full_like(all_real_atom_types_concat, -1)
remapped[valid] = index_map[all_real_atom_types_concat[valid]]
all_real_atom_types_concat = remapped
all_cells_concat = data["cells"]
all_coords_concat = data["coords"]

Expand Down Expand Up @@ -60,31 +177,44 @@ def to_system_data(folder, type_map=None, labels=True):
while True:
if all_real_atom_types_concat.size == 0:
break
temp_atom_numbs = [
np.count_nonzero(all_real_atom_types_concat[0] == i)
for i in range(len(data["atom_names"]))
]
# temp_formula = formula(data['atom_names'], temp_atom_numbs)
temp_idx = np.arange(all_real_atom_types_concat.shape[0])[
(all_real_atom_types_concat == all_real_atom_types_concat[0]).all(-1)
]
rest_idx = np.arange(all_real_atom_types_concat.shape[0])[
(all_real_atom_types_concat != all_real_atom_types_concat[0]).any(-1)
]

# Extract data for this group
group_atom_types = all_real_atom_types_concat[0]
group_coords = all_coords_concat[temp_idx]
group_extra = {}
for name in extra_data:
group_extra[name] = extra_data[name][temp_idx]
extra_data[name] = extra_data[name][rest_idx]

# Strip virtual atoms (type -1) introduced by padding
group_atom_types, group_coords, group_extra = _strip_virtual_atoms(
group_atom_types, group_coords, group_extra, dtypes
)

temp_atom_numbs = [
np.count_nonzero(group_atom_types == i)
for i in range(len(data["atom_names"]))
]

temp_data = data.copy()
temp_data["atom_names"] = data["atom_names"].copy()
temp_data["atom_numbs"] = temp_atom_numbs
temp_data["atom_types"] = all_real_atom_types_concat[0]
temp_data["atom_types"] = group_atom_types
all_real_atom_types_concat = all_real_atom_types_concat[rest_idx]
temp_data["cells"] = all_cells_concat[temp_idx]
all_cells_concat = all_cells_concat[rest_idx]
temp_data["coords"] = all_coords_concat[temp_idx]
temp_data["coords"] = group_coords
all_coords_concat = all_coords_concat[rest_idx]

for name in extra_data:
all_dtype_concat = extra_data[name]
temp_data[name] = all_dtype_concat[temp_idx]
extra_data[name] = all_dtype_concat[rest_idx]
for name in group_extra:
temp_data[name] = group_extra[name]

data_list.append(temp_data)
return data_list
Expand All @@ -109,7 +239,7 @@ def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):
comp_dump(folder, data, set_size, comp_prec, remove_sets)


def mix_system(*system, type_map, **kwargs):
def mix_system(*system, type_map, atom_numb_pad=None, **kwargs):
"""Mix the systems into mixed_type ones according to the unified given type_map.

Parameters
Expand All @@ -118,6 +248,11 @@ def mix_system(*system, type_map, **kwargs):
The systems to mix
type_map : list of str
Maps atom type to name
atom_numb_pad : int, optional
If provided, pad atom counts to the next multiple of this number
using virtual atoms (type -1 in real_atom_types). This reduces the
number of subdirectories when systems have many different atom counts.
For example, atom_numb_pad=8 groups systems into multiples of 8.
**kwargs : dict
Other parameters

Expand All @@ -129,21 +264,28 @@ def mix_system(*system, type_map, **kwargs):
mixed_systems = {}
temp_systems = {}
atom_numbs_frame_index = {} # index of frames in cur sys
# Use LabeledSystem DTYPES as superset for generic per-atom padding
dtypes = dpdata.system.LabeledSystem.DTYPES
for sys in system:
tmp_sys = sys.copy()
natom = tmp_sys.get_natoms()
tmp_sys.convert_to_mixed_type(type_map=type_map)
if str(natom) not in atom_numbs_frame_index:
atom_numbs_frame_index[str(natom)] = 0
atom_numbs_frame_index[str(natom)] += tmp_sys.get_nframes()
if str(natom) not in temp_systems or not temp_systems[str(natom)]:
temp_systems[str(natom)] = tmp_sys
if atom_numb_pad is not None and atom_numb_pad > 1:
padded_natom = math.ceil(natom / atom_numb_pad) * atom_numb_pad
_pad_to(tmp_sys.data, padded_natom, dtypes)
group_key = str(padded_natom)
else:
group_key = str(natom)
if group_key not in atom_numbs_frame_index:
atom_numbs_frame_index[group_key] = 0
atom_numbs_frame_index[group_key] += tmp_sys.get_nframes()
if group_key not in temp_systems or not temp_systems[group_key]:
temp_systems[group_key] = tmp_sys
else:
temp_systems[str(natom)].append(tmp_sys)
for natom in temp_systems:
if atom_numbs_frame_index[natom] > 0:
sys_name = f"{natom}"
mixed_systems[sys_name] = temp_systems[natom]
temp_systems[group_key].append(tmp_sys)
for natom_key in temp_systems:
if atom_numbs_frame_index[natom_key] > 0:
mixed_systems[natom_key] = temp_systems[natom_key]
return mixed_systems


Expand Down
26 changes: 24 additions & 2 deletions dpdata/plugins/deepmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ class DeePMDMixedFormat(Format):
>>> import dpdata
>>> dpdata.MultiSystems(*systems).to_deepmd_npy_mixed("mixed_dir")

Dump with ``atom_numb_pad`` to reduce the number of subdirectories.
Systems are padded with virtual atoms (type -1) so that atom counts are
rounded up to the nearest multiple of the given number:

>>> dpdata.MultiSystems(*systems).to_deepmd_npy_mixed("mixed_dir", atom_numb_pad=8)

Load a mixed type data into a MultiSystems:

>>> import dpdata
Expand Down Expand Up @@ -156,7 +162,7 @@ def from_labeled_system_mix(self, file_name, type_map=None, **kwargs):
file_name, type_map=type_map, labels=True
)

def mix_system(self, *system, type_map, **kwargs):
def mix_system(self, *system, type_map, atom_numb_pad=None, **kwargs):
"""Mix the systems into mixed_type ones according to the unified given type_map.

Parameters
Expand All @@ -165,15 +171,31 @@ def mix_system(self, *system, type_map, **kwargs):
The systems to mix
type_map : list of str
Maps atom type to name
atom_numb_pad : int, optional
If provided, pad atom counts to the next multiple of this number
using virtual atoms (type -1 in real_atom_types). This reduces the
number of subdirectories when systems have many different atom counts.
For example, ``atom_numb_pad=8`` groups systems into multiples of 8:
a 5-atom system is padded to 8, a 9-atom system is padded to 16, etc.
Virtual atoms are transparently removed when loading the data back.
**kwargs : dict
other parameters

Returns
-------
mixed_systems: dict
dict of mixed system with key 'atom_numbs'

Examples
--------
Dump with padding so that atom counts are rounded up to multiples of 8:

>>> import dpdata
>>> dpdata.MultiSystems(*systems).to_deepmd_npy_mixed("mixed_dir", atom_numb_pad=8)
"""
return dpdata.deepmd.mixed.mix_system(*system, type_map=type_map, **kwargs)
return dpdata.deepmd.mixed.mix_system(
*system, type_map=type_map, atom_numb_pad=atom_numb_pad, **kwargs
)

def from_multi_systems(self, directory, **kwargs):
register_spin()
Expand Down
Loading