diff --git a/NDP-HNN/train.py b/NDP-HNN/train.py index d8f12db..6bae995 100644 --- a/NDP-HNN/train.py +++ b/NDP-HNN/train.py @@ -31,10 +31,10 @@ def train_model(model, #--- forward one snapshot state, pred_xyz, inc_logits = model(data, state) - #--- mask nodes that are alive at next time step (t+1) + #--- mask nodes that are alive at current time step (t) t = int(data.t[0].item()) mask_next = torch.tensor( - [birth_times[c] <= (t + 1) for c in cells], + [birth_times[c] <= t for c in cells], dtype=torch.bool, device=device ) target_xyz = torch.tensor(birth_feat[:, :3], device=device)[mask_next]