Merged
52 commits
c758ad5
The main compression function for a model
danielkorzekwa Oct 27, 2025
8af9903
Code formatting
danielkorzekwa Oct 27, 2025
5ba6c27
Model search space configuration used by test_compress.py test.
danielkorzekwa Oct 27, 2025
0bc5d84
Tokenizer used by test_compress.py test.
danielkorzekwa Oct 27, 2025
87d4fa5
Tokenizer utility used by test_compress.py test
danielkorzekwa Oct 27, 2025
ced1e99
e2e tests for compress.py
danielkorzekwa Oct 27, 2025
5de0bdc
Add convert_llama3_config_to_decilm_config + unit test
danielkorzekwa Oct 27, 2025
800414c
Remove unused bypass distillation config files.
danielkorzekwa Oct 27, 2025
16abcc9
Moving integration tests to tests/experimental to not trigger CICD
danielkorzekwa Oct 27, 2025
a5ba1c7
update docs
danielkorzekwa Oct 27, 2025
1bda391
Replace mprint with print and replace osp.join with path1 / path2 not…
danielkorzekwa Oct 27, 2025
bb38401
Refactor file checking assertions to use .is_file() and .exists()
danielkorzekwa Oct 27, 2025
8415548
Add a new dependency section to setyp.py for the modelopt.torch._comp…
danielkorzekwa Oct 27, 2025
b1b1833
Move test_convert_llama3_config_to_decilm_config.py to tests/experime…
danielkorzekwa Oct 27, 2025
d4ffc91
Merge branch 'feature/compress' into dkorzekwa/e2e_compression_test
kevalmorabia97 Oct 27, 2025
6f28e4a
Fix: Add missing LICENSE headers
kevalmorabia97 Oct 27, 2025
016fb63
Use spawn_multiprocess_job for test_compress test (to be able to use …
danielkorzekwa Oct 28, 2025
0ccf1c4
Add comments.
danielkorzekwa Oct 28, 2025
58439ca
Add _save_dummy_dataset to the test_compress.py
danielkorzekwa Oct 28, 2025
2e5f776
Refactoring: Move torch distributed env variables to dist_utils.py
danielkorzekwa Oct 28, 2025
6274db5
Refactoring: move torch distributed variables to dist_utils
danielkorzekwa Oct 28, 2025
d942e0a
Move os.environ["WANDB_DISABLED"] = "true" to dist_utils.py
danielkorzekwa Oct 28, 2025
f765921
Implement integration test for mnt.convert() for the _compress algori…
danielkorzekwa Oct 28, 2025
de876d6
Implement mtn.convert() for compress() algorithm.
danielkorzekwa Oct 28, 2025
72bdc7a
Merge branch 'dkorzekwa/e2e_compression_test' into dkorzekwa/llama3_t…
danielkorzekwa Oct 28, 2025
40d28af
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa Oct 28, 2025
f7fe23c
Fix broken test - incorrect package names.
danielkorzekwa Oct 28, 2025
3d1d286
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa Oct 28, 2025
a210483
Implementing nas.convert for compress algorithm.
danielkorzekwa Oct 28, 2025
739f868
Improve docs
danielkorzekwa Oct 28, 2025
b06d22b
Merge branch 'dkorzekwa/e2e_compression_test' into dkorzekwa/llama3_t…
danielkorzekwa Oct 28, 2025
9352978
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa Oct 28, 2025
20a3c5e
Code cleanup.
danielkorzekwa Oct 28, 2025
18cb88b
Merge branch 'feature/compress' into dkorzekwa/llama3_to_decilm_conve…
danielkorzekwa Oct 28, 2025
1033c81
Fix import
danielkorzekwa Oct 28, 2025
0680c45
simplify code
danielkorzekwa Oct 29, 2025
2d9da30
implementing compress_nas_plugin
danielkorzekwa Oct 29, 2025
febab44
code clean up.
danielkorzekwa Oct 29, 2025
86bf394
code clean up
danielkorzekwa Oct 29, 2025
86e04a0
create conftest.py with shared test logic for compress tests.
danielkorzekwa Oct 29, 2025
ae61644
code cleanup
danielkorzekwa Oct 29, 2025
2998cdb
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa Oct 29, 2025
3778ec2
code refactoring
danielkorzekwa Oct 29, 2025
d940000
refactoring
danielkorzekwa Oct 29, 2025
0bf9a92
move test utilities from conftest.py to test_utils.py
danielkorzekwa Oct 29, 2025
b56df9a
Improve comments
danielkorzekwa Oct 29, 2025
fd63130
Merge branch 'feature/compress' into dkorzekwa/nas_convert
danielkorzekwa Oct 29, 2025
9bfcc21
Added TODO.
danielkorzekwa Oct 29, 2025
6504c44
Utilitities for hydra initialization
danielkorzekwa Oct 30, 2025
d0fb8f9
Code refactoring
danielkorzekwa Oct 30, 2025
40f18b2
code refactoring
danielkorzekwa Oct 30, 2025
936556f
Add compress dependencies to setup.py.
danielkorzekwa Oct 30, 2025
3 changes: 1 addition & 2 deletions modelopt/torch/_compress/compress.py
@@ -28,8 +28,7 @@
from omegaconf import DictConfig
from puzzle_tools.runtime import IRuntime

# TODO Move initialize_hydra_config_for_dir from tests to main
from tests.utils.test_utils import initialize_hydra_config_for_dir
from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir


def compress(
@@ -19,6 +19,7 @@
#!/usr/bin/env python3
from pathlib import Path

import torch
from fire import Fire
from puzzle_tools.checkpoint_utils import copy_tokenizer
from puzzle_tools.checkpoint_utils_hf import copy_deci_lm_hf_code
@@ -46,7 +47,7 @@ def convert_llama3_config_to_decilm_config(config: LlamaConfig) -> DeciLMConfig:
dtype = getattr(config, "torch_dtype", None)

# Convert torch.dtype to string if needed (for JSON serialization)
if dtype is not None and hasattr(dtype, "__module__") and "torch" in dtype.__module__:
if dtype is not None and isinstance(dtype, torch.dtype):
dtype = str(dtype).replace("torch.", "")

# Track which global values will be removed (moved to per-layer configs)
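The hunk above replaces a fragile `__module__` string check with `isinstance(dtype, torch.dtype)`. The new check can be sketched in isolation (a minimal illustration, separate from the PR's conversion code; the helper name is made up):

```python
import torch


def dtype_to_json_string(dtype):
    # Convert a torch.dtype to a plain string for JSON serialization,
    # e.g. torch.bfloat16 -> "bfloat16"; pass non-dtype values through.
    if dtype is not None and isinstance(dtype, torch.dtype):
        dtype = str(dtype).replace("torch.", "")
    return dtype


print(dtype_to_json_string(torch.bfloat16))  # bfloat16
print(dtype_to_json_string("float16"))       # float16
```

Unlike the old `hasattr(dtype, "__module__")` check, `isinstance` cannot be fooled by unrelated objects that merely live in a module whose name contains "torch".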
54 changes: 54 additions & 0 deletions modelopt/torch/_compress/hydra.py
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from hydra import compose, initialize, initialize_config_dir
from omegaconf import DictConfig, OmegaConf
Comment on lines +16 to +17
Collaborator

Please add hydra and omegaconf to setup.py puzzle dependencies

Contributor Author

Done. Do we have an integration test that covers setup.py? For:

 # Dependencies for modelopt.torch._compress subpackage
    "compress": [
        "fire",
        "hydra-core==1.3.2",
        "omegaconf==2.3.0",
    ],

Collaborator

I will work on the integration test next week


"""
Utilities for hydra config initialization.
"""


def initialize_hydra_config_for_dir(
config_dir: str, config_name: str, overrides: list[str]
) -> DictConfig:
"""Initialize a hydra config from an absolute path for a config directory

Args:
config_dir (str):
config_name (str):
overrides (List[str]):

Returns:
DictConfig:
"""

with initialize_config_dir(version_base=None, config_dir=config_dir):
args = compose(config_name, overrides)
args._set_flag("allow_objects", True)
OmegaConf.resolve(args) # resolve object attributes
OmegaConf.set_struct(args, False)

return args


def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig:
with initialize(version_base=None, config_path=config_path):
args = compose(config_name, overrides)
args._set_flag("allow_objects", True)
OmegaConf.resolve(args) # resolve object attributes
OmegaConf.set_struct(args, False)

return args
167 changes: 167 additions & 0 deletions modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py
@@ -0,0 +1,167 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Compress NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146).
"""

import datetime
from pathlib import Path

import pruning_ckpts
import score_pruning_activations
import torch
from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm
from torch import nn

from modelopt.torch._compress.hydra import initialize_hydra_config_for_dir
from modelopt.torch._compress.runtime import NativeDdpRuntime
from modelopt.torch.nas.conversion import NASModeRegistry
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
from modelopt.torch.opt.mode import (
ConvertEntrypoint,
ConvertReturnType,
MetadataDict,
ModeDescriptor,
RestoreEntrypoint,
)
from modelopt.torch.opt.searcher import BaseSearcher


class CompressModel(nn.Module):
pass # No model implementation is needed for the compress mode


class CompressConfig(ModeloptBaseConfig):
"""Configuration for Compress NAS algorithm."""

# Input model path to compress in the HF format
input_model_path: str = ModeloptField(
default="",
title="",
description="",
)

# Hydra config directory containing the search space definition
hydra_config_dir: str = ModeloptField(
default="",
title="",
description="",
)

# Hydra config name containing the search space definition
hydra_config_name: str = ModeloptField(
default="",
title="",
description="",
)

# Directory to save the compressed model and intermediate results
puzzle_dir: str = ModeloptField(
default="",
title="",
description="",
)

# Dataset path to use for scoring in pruning and NAS search
dataset_path: str = ModeloptField(
default="",
title="",
description="",
)


def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertReturnType:
"""1. Convert the model from HF format to DeciLM format.
2. Score the pruning activations.
3. Prune the model and save pruned checkpoints.

The output of this step will be used by mtn.search() to perform the NAS search.
"""
runtime = NativeDdpRuntime(
dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
)

# Load hydra config
hydra_cfg = initialize_hydra_config_for_dir(
config_dir=config.hydra_config_dir,
config_name=config.hydra_config_name,
overrides=[
f"puzzle_dir={config.puzzle_dir}",
f"dataset_path={config.dataset_path}",
],
)

# Convert Llama3 model to DeciLM model
hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable
convert_llama3_to_decilm(
input_dir=config.input_model_path,
output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir,
)

# Score_pruning_activations (distributed processing)
score_pruning_activations.launch_score_activations(hydra_cfg, runtime)

# Prune the model and save pruned checkpoints
if runtime.global_rank == 0:
pruning_ckpts.launch_prune_ckpt(hydra_cfg)
runtime.wait_for_everyone()

return model, {}


def restore_compress_model(
model: nn.Module, config: CompressConfig, metadata: MetadataDict
) -> nn.Module:
"""Restore is not needed for the compress mode as we are not saving any model state"""
return model


@NASModeRegistry.register_mode
class CompressDescriptor(ModeDescriptor):
"""Descriptor for the Compress mode."""

@property
def name(self) -> str:
"""String identifier for this mode."""
return "compress"

@property
def config_class(self) -> type[ModeloptBaseConfig]:
"""Configuration class for this mode."""
return CompressConfig

@property
def search_algorithm(self) -> type[BaseSearcher]:
"""Return the associated searcher implementation."""
raise NotImplementedError("Compress mode does not have a search algorithm yet.")

@property
def convert(self) -> ConvertEntrypoint:
"""Entrypoint to convert a model."""
return convert_compress_model

@property
def restore(self) -> RestoreEntrypoint:
"""Entrypoint to restore a model."""
return restore_compress_model

@property
def export_mode(self) -> str | None:
"""The mode that corresponds to the export mode.
For now, this will be a no-op as there is no modelopt's concept of search space defined
for the compress algorithm.
"""
return "export_nas"
6 changes: 5 additions & 1 deletion setup.py
@@ -100,7 +100,11 @@
"setuptools-scm>=8",
],
# Dependencies for modelopt.torch._compress subpackage
"compress": ["fire"],
"compress": [
"fire",
"hydra-core==1.3.2",
"omegaconf==2.3.0",
],
}

# create "compound" optional dependencies
119 changes: 119 additions & 0 deletions tests/experimental/torch/_compress/compress_test_utils.py
@@ -0,0 +1,119 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from pathlib import Path

import torch
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase


def create_and_save_small_llama_model(
output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase
):
"""
Create and save a small Llama model for testing the conversion pipeline.
This mimics having a real Llama checkpoint that needs to be converted.
"""
os.makedirs(output_path, exist_ok=True)

# Create a minimal Llama config (small for testing)
# Note: intermediate_size must be divisible by 256 per DeciLM config requirements
# Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility
llama_config = LlamaConfig(
vocab_size=vocab_size,
hidden_size=256, # 32 heads times 8 head_dim = 256 (matches bypass config expectations)
intermediate_size=512, # Must be divisible by 256
num_hidden_layers=2,
num_attention_heads=32, # Matches original test
num_key_value_heads=8, # GQA: 32÷4=8 (matches original n_heads_in_group=4)
max_position_embeddings=512,
rms_norm_eps=1e-5,
rope_theta=10000.0,
attention_bias=False,
hidden_act="silu",
tie_word_embeddings=False,
)

# Create and save the Llama model
model = LlamaForCausalLM(llama_config)
model.to(dtype=torch.bfloat16).save_pretrained(output_path)

# Save tokenizer
tokenizer.save_pretrained(output_path)

# Save config
llama_config.save_pretrained(output_path)
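The divisibility constraints noted in the config comments above can be checked with a few lines of plain Python (a sketch of the arithmetic, not part of the test code):

```python
# Values from the test's LlamaConfig above.
hidden_size = 256
num_attention_heads = 32
num_key_value_heads = 8
intermediate_size = 512

# Per-head dimension and GQA group size implied by the config.
head_dim = hidden_size // num_attention_heads
n_heads_in_group = num_attention_heads // num_key_value_heads

assert head_dim >= 8, "Flash Attention 2 needs head_dim >= 8"
assert intermediate_size % 256 == 0, "DeciLM requires intermediate_size divisible by 256"
print(head_dim, n_heads_in_group)  # 8 4
```

This confirms the comment "GQA: 32/4=8": a group size of 4 query heads per key/value head gives the 8 key/value heads in the config.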


def create_tokenizer(project_root_path: Path) -> PreTrainedTokenizerBase:
"""
Create a tokenizer for the Llama model.
"""
tokenizer_path = project_root_path / "tests/experimental/torch/_compress/resources/tokenizer"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
return tokenizer


def setup_puzzle_dir(puzzle_dir: str):
"""
Setup puzzle directory by removing existing directory and creating a new one.
"""
if Path(puzzle_dir).exists():
shutil.rmtree(puzzle_dir)
Path(puzzle_dir).mkdir(parents=True, exist_ok=True)


def save_dummy_dataset(dataset_path: str):
"""
Save a dummy dataset for testing purposes.
"""
# dummy sample
sample = [
{"role": "user", "content": "please cite Lorem Ipsum?"},
{
"role": "assistant",
"content": (
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. "
"Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, "
"in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, "
"dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, "
"pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. "
"Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, "
"sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. "
"Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, "
"nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. "
"Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, "
"faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. "
"Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. "
"Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. "
"Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. "
"Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. "
"Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. "
"Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. "
"Donec mollis convallis massa quis iaculis."
),
},
]

# Repeat the sample to build train and val splits; 2500 copies yield
# about 128 packed samples at block-size 8192 with the Llama3 tokenizer
data = [{"conversation": sample}] * 2500

# For train-val splits
data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)})
data_dict.save_to_disk(dataset_path)
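One subtlety of the `[{"conversation": sample}] * 2500` construction above: list multiplication repeats references to the same dict rather than copying it, which is harmless for a read-only dummy dataset but would matter if entries were mutated later. A quick sketch of that behavior:

```python
sample = [{"role": "user", "content": "hi"}]

# List multiplication: 2500 references to one dict, not 2500 copies.
data = [{"conversation": sample}] * 2500

print(len(data))           # 2500
print(data[0] is data[1])  # True: every entry aliases the same dict
```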