Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
694c317
Add decilm modelling code
danielkorzekwa Nov 3, 2025
991659f
Add decilm modelling code.
danielkorzekwa Nov 3, 2025
8489cee
Add transformers codebase
danielkorzekwa Nov 3, 2025
f0afefe
Add transformers code
danielkorzekwa Nov 3, 2025
b3ed5bc
Add decilm modelling code
danielkorzekwa Nov 3, 2025
a700da5
Add decilm modelling code
danielkorzekwa Nov 3, 2025
b59b679
Correct licence headers
danielkorzekwa Nov 4, 2025
1abdf3e
Correct licence headers
danielkorzekwa Nov 4, 2025
66609b1
Add decilm code
danielkorzekwa Nov 4, 2025
7da0a8a
Add decilm code
danielkorzekwa Nov 4, 2025
6e09a81
Add decilm code
danielkorzekwa Nov 4, 2025
2e3f5da
Add decilm code
danielkorzekwa Nov 4, 2025
418890e
Add decilm code
danielkorzekwa Nov 4, 2025
01f4fc1
Make llama3 converter self-contained (no deps on internal Nvidia code)
danielkorzekwa Nov 4, 2025
c57eed4
Add common module
danielkorzekwa Nov 4, 2025
3dc37b3
module refactoring
danielkorzekwa Nov 4, 2025
10ffdfe
refactoring
danielkorzekwa Nov 5, 2025
27a4456
add shared_checkpointing_utils
danielkorzekwa Nov 5, 2025
b0e22b7
Add json tools
danielkorzekwa Nov 5, 2025
52e7827
add logger
danielkorzekwa Nov 5, 2025
f5c1c87
import refactoring
danielkorzekwa Nov 5, 2025
0aa6320
add post_init_sparse module
danielkorzekwa Nov 5, 2025
35d0dbc
Add post_init_sparse
danielkorzekwa Nov 5, 2025
e39a1ad
merging hydra.py and hydra_utils.py
danielkorzekwa Nov 5, 2025
3f0772b
Merge branch 'feature/compress' into dkorzekwa/llama_converter_selfco…
danielkorzekwa Nov 13, 2025
872d6c3
Delete not used tokenizer
danielkorzekwa Nov 13, 2025
eb60a1c
Refactor imports
danielkorzekwa Nov 14, 2025
c1533fa
Improve comments + move llama convert pytest from experimental to gpu…
danielkorzekwa Nov 14, 2025
763a4d5
fix broken integration test
danielkorzekwa Nov 14, 2025
76df47c
Improve pydocs
danielkorzekwa Nov 14, 2025
298050d
Remove try except around import omegaconf
danielkorzekwa Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modelopt/torch/_compress/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from omegaconf import DictConfig
from puzzle_tools.runtime import IRuntime

from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir
from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir


def compress(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

"""
Compress NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146).

It is used by mtn.convert() to convert a model from HF format to DeciLM format + do pruning scoring
and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search.
"""

import datetime
Expand All @@ -31,7 +34,7 @@
from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (
convert_llama3_to_decilm,
)
from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir
from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir
from modelopt.torch._compress.tools.logger import mprint
from modelopt.torch._compress.tools.runtime import NativeDdpRuntime
from modelopt.torch.nas.conversion import NASModeRegistry
Expand Down
7 changes: 6 additions & 1 deletion modelopt/torch/_compress/tools/checkpoint_utils.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but do you think it would be better to move this in the DeciLM folder?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe. Actually, initially I moved it to DeciLM but then I realized how many places it is used in. Also, not all of the logic is DeciLM-specific. Given that we plan to refactor (or even remove) DeciLM, I think it is bad timing.

Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# limitations under the License.
# mypy: ignore-errors

"""
It provides general utilities for loading and initializing PyTorch model checkpoints,
particularly for DeciLM models.
"""

import concurrent.futures
import warnings
from functools import partial
Expand Down Expand Up @@ -51,7 +56,7 @@ def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]:
if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or (
checkpoint_dir / SAFE_WEIGHTS_NAME
).exists():
from utils.sharded_checkpoint_utils import (
from modelopt.torch._compress.tools.sharded_checkpoint_utils import (
load_sharded_state_dict, # local import to avoid circular import
)

Expand Down
20 changes: 13 additions & 7 deletions modelopt/torch/_compress/tools/checkpoint_utils_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# limitations under the License.
# mypy: ignore-errors

"""
Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format,
particularly for DeciLM models.
"""

import concurrent.futures
import fcntl
import os
Expand All @@ -26,15 +31,16 @@
from typing import Any, BinaryIO

import torch
from logger import mprint
from puzzle_tools import deci_lm_hf_code
from puzzle_tools.common import infer_weights_dtype
from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig
from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
from puzzle_tools.robust_json import json_dumps
from safetensors.torch import save_file as safe_save_file
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from utils.post_init_sparse import SparsityMethod

from modelopt.torch._compress.decilm import deci_lm_hf_code
from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig
from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
from modelopt.torch._compress.tools.common import infer_weights_dtype
from modelopt.torch._compress.tools.logger import mprint
from modelopt.torch._compress.tools.post_init_sparse import SparsityMethod
from modelopt.torch._compress.tools.robust_json import json_dumps

SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors"
PTH_SUBBLOCKS_DIR_NAME = "subblocks"
Expand Down
81 changes: 81 additions & 0 deletions modelopt/torch/_compress/tools/hydra_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utilities for hydra config initialization.
"""

import datetime
import random
from pathlib import Path

from hydra import compose, initialize, initialize_config_dir
from hydra.utils import get_object
from omegaconf import DictConfig, OmegaConf


def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int:
    """Return the number of warmup steps: ``pct`` of the total training steps, at least 1.

    Total steps are computed as ``(tokens // block) // mbs``.  Registered as the
    ``warmup_steps`` resolver for hydra configs.

    Args:
        tokens: total number of training tokens (int-convertible).
        block: sequence/block length in tokens (int-convertible).
        mbs: micro batch size (int-convertible).
        pct: fraction of total steps to use as warmup (default 0.05).

    Returns:
        Warmup step count, rounded, with a floor of 1.
    """
    total_steps = (int(tokens) // int(block)) // int(mbs)
    return max(1, round(pct * total_steps))


def register_hydra_resolvers():
    """Register the custom OmegaConf resolvers referenced by the hydra configs.

    Resolvers: ``to_path`` (str -> Path), ``random_int`` (inclusive random int),
    ``timedelta_minutes`` (minutes -> timedelta, passing None through),
    ``warmup_steps`` (see :func:`warmup_steps`), and ``get_object``
    (dotted path -> Python object via hydra).
    """
    OmegaConf.register_new_resolver("to_path", lambda p: Path(p))
    OmegaConf.register_new_resolver(
        "random_int", lambda lo, hi: random.randint(int(lo), int(hi))
    )
    OmegaConf.register_new_resolver(
        "timedelta_minutes",
        lambda minutes: None if minutes is None else datetime.timedelta(minutes=minutes),
    )
    OmegaConf.register_new_resolver(
        "warmup_steps", lambda tokens, block, mbs, pct: warmup_steps(tokens, block, mbs, pct)
    )
    OmegaConf.register_new_resolver("get_object", lambda dotted: get_object(dotted))


def initialize_hydra_config_for_dir(
    config_dir: str, config_name: str, overrides: list[str]
) -> DictConfig:
    """Initialize a hydra config from an absolute path for a config directory

    Args:
        config_dir (str): absolute path of the directory containing the config files.
        config_name (str): name of the config to compose (file name without extension).
        overrides (List[str]): hydra-style override strings, e.g. ``["key=value"]``.

    Returns:
        DictConfig: the composed config with all interpolations resolved and
            struct mode disabled (so new keys may be added by callers).
    """

    with initialize_config_dir(version_base=None, config_dir=config_dir):
        args = compose(config_name, overrides)
        # NOTE(review): private OmegaConf API — presumably allows non-primitive
        # (object) values in the config; confirm against omegaconf internals.
        args._set_flag("allow_objects", True)
        OmegaConf.resolve(args)  # resolve object attributes
        OmegaConf.set_struct(args, False)

    return args


def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig:
    """Initialize a hydra config from a config path (relative to the caller's module).

    Args:
        config_path: config directory path, interpreted by hydra's ``initialize``.
        config_name: name of the config to compose (file name without extension).
        overrides: hydra-style override strings, e.g. ``["key=value"]``.

    Returns:
        The composed config with interpolations resolved and struct mode disabled.
    """
    with initialize(version_base=None, config_path=config_path):
        cfg = compose(config_name, overrides)
        cfg._set_flag("allow_objects", True)
        OmegaConf.resolve(cfg)  # resolve object attributes
        OmegaConf.set_struct(cfg, False)
    return cfg
129 changes: 129 additions & 0 deletions modelopt/torch/_compress/tools/post_init_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors
import torch
from torch import nn
from torch.nn.utils.prune import custom_from_mask

from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM

"""
Converts a state dictionary from PyTorch's pruning format (with _orig and _mask suffixes)
into a standard format with sparsified weights.
"""


class SparsityMethod:
    """Base class for post-init weight sparsification of DeciLM models.

    Subclasses implement :meth:`calculate_masks` (and a static
    ``filter_function(name, patterns)``) to define which linear layers are
    sparsified and how their binary masks are computed.
    """

    def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Given a model state_dict, return a state_dict-like mask_dict with masks.

        The base implementation is a stub (returns None); subclasses override it.
        """

    @staticmethod
    def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False):
        """Collapse PyTorch pruning-format tensors into plain masked weights.

        PyTorch's pruning utilities store each pruned parameter as a
        ``<name>_orig`` / ``<name>_mask`` pair.  This merges each pair into a
        single ``<name>`` tensor with masked entries zeroed, mutating
        ``state_dict`` in place.

        Args:
            state_dict: model state_dict (modified in place).
            verbose: if True, print the measured sparsity of each merged tensor.
            change_dtype: if True, cast every remaining tensor to bfloat16.

        Returns:
            Tuple ``(state_dict, sparsity_masks)`` where ``sparsity_masks`` maps
            module names (weight key minus the ".weight" suffix) to their masks.
        """
        sparsity_masks = {}
        for name in list(state_dict.keys()):
            # NOTE: replaces every "_orig" occurrence; assumes it only appears
            # as the pruning suffix in parameter names.
            original_name = name.replace("_orig", "")
            mask_name = original_name + "_mask"
            if name[-4:] == "orig" and mask_name in state_dict:
                val = state_dict[name]
                mask = state_dict[name[:-4] + "mask"]
                val[mask == 0] = 0
                # Fraction of zero entries after masking (includes natural zeros).
                sparsity = (val == 0).sum() / mask.numel()
                # original_name ends with ".weight"; drop it to get the module name.
                sparsity_masks[original_name[:-7]] = mask
                if verbose:
                    print(f"fix_state_dict_inplace: {name} {sparsity=}")
                del state_dict[mask_name]
                del state_dict[name]
                state_dict[original_name] = val
        if change_dtype:
            for name in state_dict:
                state_dict[name] = state_dict[name].to(torch.bfloat16)
        return state_dict, sparsity_masks

    def filter_function(self):
        """Interface stub; subclasses override with ``(name, modules_to_sparsify_in_block)``.

        NOTE(review): callers invoke ``self.filter_function(name, patterns)`` with
        two arguments, so only subclass overrides are actually usable.
        """

    def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> None:
        """Attach each mask to its module's ``weight`` via torch's pruning machinery."""
        for name, module in model.named_modules():
            if name in mask_dict:
                custom_from_mask(module, "weight", mask_dict[name].to(module.weight.device))
                print(name)
                print(torch.sum(mask_dict[name]) / mask_dict[name].numel())

    # String annotation: lazy forward reference, so defining this class does not
    # require the DeciLM import to have succeeded at function-definition time.
    def do_sparsity(self, model: "DeciLMForCausalLM", mask_dict=None):
        """Select the configured linear layers of ``model`` and apply sparsity masks.

        Args:
            model: DeciLM model whose ``config.block_configs`` name the layers to
                sparsify per block (``ffn.sparsify`` / ``attention.sparsify``).
            mask_dict: optional precomputed masks; when None, masks are computed
                with :meth:`calculate_masks` from the selected weights.
        """
        full_name_layers = []
        for block_idx, block_config in enumerate(model.config.block_configs):
            ffn_names = block_config.ffn.sparsify  # layers_to_sparsify_pattern[block_idx]
            att_name = block_config.attention.sparsify
            block = model.model.layers[block_idx]
            if hasattr(block, "mlp"):
                for name, m in block.mlp.named_modules():
                    if isinstance(m, torch.nn.Linear) and self.filter_function(name, ffn_names):
                        full_name_layers.append(
                            "model.layers." + str(block_idx) + "." + "mlp." + name
                        )
            if hasattr(block, "self_attn"):
                for name, m in block.self_attn.named_modules():
                    if isinstance(m, torch.nn.Linear) and self.filter_function(name, att_name):
                        full_name_layers.append(
                            "model.layers." + str(block_idx) + "." + "self_attn." + name
                        )

        if mask_dict is None:
            # BUG FIX: str.rstrip(".weight") strips any run of trailing characters
            # drawn from {., w, e, i, g, h, t} (e.g. "gate.weight" -> "ga"), not
            # the literal suffix; removesuffix drops exactly ".weight".
            state_dict_for_sparsifying = {
                k.removesuffix(".weight"): v
                for k, v in model.state_dict().items()
                if k.removesuffix(".weight") in full_name_layers
            }
            mask_dict = self.calculate_masks(state_dict_for_sparsifying)

        self.apply_masks(model, mask_dict)


class SparsityMethod2o4(SparsityMethod):
    """2:4 structured sparsity: keep the 2 largest-magnitude weights in each group of 4."""

    def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Given a model state_dict, return a mask_dict with a 2:4 mask per tensor."""
        mask_dict = {}
        for key, val in state_dict.items():
            orig_size = val.shape
            scores = val.flatten() ** 2  # rank weights by squared magnitude
            mask = self.create_mask(scores)
            mask = mask.reshape(orig_size)
            mask_dict[key] = mask
        return mask_dict

    def create_mask(self, score, value=0):
        """Build a {0,1} float mask keeping the top-2 entries of each consecutive group of 4.

        Args:
            score: importance scores; ``score.numel()`` must be divisible by 4.
            value: unused; kept for backward compatibility of the signature.

        Returns:
            Mask tensor with the same shape as ``score``, on ``score``'s device.
        """
        orig_size = score.shape
        score = score.view(-1, 4)
        # BUG FIX: allocate mask and row indices on score's device; previously
        # they were created on CPU, which fails for GPU-resident scores.
        mask = torch.zeros(score.shape, device=score.device)
        values, indices = torch.topk(score, 2, dim=1)
        rows = torch.arange(mask.size(0), device=score.device).unsqueeze(-1)
        mask[rows, indices] = 1
        mask = mask.view(orig_size)
        return mask

    @staticmethod
    def filter_function(name, modules_to_sparsify_in_block):
        """Return True when ``name`` is listed in the block's sparsify patterns (None -> False)."""
        if modules_to_sparsify_in_block is None:
            return False
        return name in modules_to_sparsify_in_block
72 changes: 72 additions & 0 deletions modelopt/torch/_compress/tools/robust_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# mypy: ignore-errors

"""
Provides a robust JSON encoder that can handle various types of objects,
including dataclasses, paths, enums, namespaces, and functions.
"""

import argparse
import dataclasses
import datetime
import inspect
import json
from enum import Enum
from pathlib import Path
from typing import Any

from omegaconf import DictConfig, ListConfig, OmegaConf


class RobustJSONEncoder(json.JSONEncoder):
    """JSON encoder that handles common non-JSON-native Python types.

    Supported: dataclass instances, ``Path``, ``Enum`` members,
    ``argparse.Namespace``, torch dtypes (matched by type name so torch need
    not be importable), OmegaConf containers, functions/methods, and
    ``timedelta``.  Anything else falls through to the default encoder.
    """

    def default(self, o):
        if dataclasses.is_dataclass(o):
            return dataclasses.asdict(o)
        if isinstance(o, Path):
            return str(o)
        if isinstance(o, Enum):
            return o.name
        if isinstance(o, argparse.Namespace):
            return vars(o)
        # Matched by type name to avoid a hard torch dependency.
        if type(o).__name__ == "dtype":
            return str(o)
        if isinstance(o, (DictConfig, ListConfig)):
            return OmegaConf.to_container(o, resolve=True)
        if inspect.isfunction(o) or inspect.ismethod(o):
            # Functions defined in __main__ have no stable importable path,
            # so fall back to the bare name.
            return o.__name__ if o.__module__ == "__main__" else f"{o.__module__}.{o.__qualname__}"
        if isinstance(o, datetime.timedelta):
            return str(o)
        return super().default(o)


def json_dumps(obj: Any) -> str:
    """Serialize ``obj`` to a pretty-printed (indent=2) JSON string using RobustJSONEncoder."""
    return json.dumps(obj, indent=2, cls=RobustJSONEncoder)


def json_dump(obj: Any, path: Path | str) -> None:
    """Serialize ``obj`` as JSON and write it to ``path``, creating parent dirs as needed."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json_dumps(obj))


def json_load(path: Path | str) -> dict:
path = Path(path)
text = path.read_text()
return json.loads(text)
Loading
Loading