Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
c758ad5
The main compression function for a model
danielkorzekwa Oct 27, 2025
8af9903
Code formatting
danielkorzekwa Oct 27, 2025
5ba6c27
Model search space configuration used by test_compress.py test.
danielkorzekwa Oct 27, 2025
0bc5d84
Tokenizer used by test_compress.py test.
danielkorzekwa Oct 27, 2025
87d4fa5
Tokenizer utility used by test_compress.py test
danielkorzekwa Oct 27, 2025
ced1e99
e2e tests for compress.py
danielkorzekwa Oct 27, 2025
800414c
Remove unused bypass distillation config files.
danielkorzekwa Oct 27, 2025
16abcc9
Moving integration tests to tests/experimental to not trigger CICD
danielkorzekwa Oct 27, 2025
a5ba1c7
update docs
danielkorzekwa Oct 27, 2025
1bda391
Replace mprint with print and replace osp.join with path1 / path2 not…
danielkorzekwa Oct 27, 2025
bb38401
Refactor file checking assertions to use .is_file() and .exists()
danielkorzekwa Oct 27, 2025
d4ffc91
Merge branch 'feature/compress' into dkorzekwa/e2e_compression_test
kevalmorabia97 Oct 27, 2025
6f28e4a
Fix: Add missing LICENSE headers
kevalmorabia97 Oct 27, 2025
016fb63
Use spawn_multiprocess_job for test_compress test (to be able to use …
danielkorzekwa Oct 28, 2025
0ccf1c4
Add comments.
danielkorzekwa Oct 28, 2025
58439ca
Add _save_dummy_dataset to the test_compress.py
danielkorzekwa Oct 28, 2025
2e5f776
Refactoring: Move torch distributed env variables to dist_utils.py
danielkorzekwa Oct 28, 2025
6274db5
Refactoring: move torch distributed variables to dist_utils
danielkorzekwa Oct 28, 2025
d942e0a
Move os.environ["WANDB_DISABLED"] = "true" to dist_utils.py
danielkorzekwa Oct 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions modelopt/torch/_compress/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Experimental model compression algorithm based on Local Neural Architecture Search,
following the Puzzle paper: <https://arxiv.org/abs/2411.19146>.
Proof of concept for the Llama 3.1 model.
82 changes: 82 additions & 0 deletions modelopt/torch/_compress/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""

This module provides the main compression function for a model
using MIP-based NAS search algorithm.

"""

import build_library_and_stats
import mip_and_realize_models
import pruning_ckpts
import score_pruning_activations
import scoring
from omegaconf import DictConfig
from puzzle_tools.runtime import IRuntime

# TODO Move initialize_hydra_config_for_dir from tests to main
from tests.utils.test_utils import initialize_hydra_config_for_dir


def compress(
    hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str, runtime: IRuntime
) -> DictConfig:
    """Run the full MIP-based NAS compression pipeline on a puzzletron model.

    The pipeline executes, in order: activation scoring, checkpoint pruning,
    library/stats building, per-block scoring, and MIP search + model realization.
    Distributed steps run on every rank; single-process steps run on global rank 0
    only, with a barrier afterwards so all ranks stay in sync.

    Args:
        hydra_config_dir (str): path to a hydra_config_dir that defines the search space
        hydra_config (str): the corresponding hydra config file
        puzzle_dir (str): directory with a puzzletron model to compress
        dataset_path (str): dataset used for scoring and distillation
        runtime: distributed runtime used to run the compression steps, e.g.,
            NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10))

    Returns:
        The hydra config object shared by all compression steps.
        @TODO: Investigate if this config object is immutable across steps and clarify
    """
    # Load the puzzletron hydra config, injecting the run-specific paths.
    cfg = initialize_hydra_config_for_dir(
        config_dir=hydra_config_dir,
        config_name=hydra_config,
        overrides=[
            f"puzzle_dir={puzzle_dir}",
            f"dataset_path={dataset_path}",
        ],
    )

    def _rank0_then_sync(step_fn):
        # Run a single-process step on global rank 0 only, then barrier all ranks.
        if runtime.global_rank == 0:
            step_fn(cfg)
        runtime.wait_for_everyone()

    # Score pruning activations (distributed).
    score_pruning_activations.launch_score_activations(cfg, runtime)

    # Prune checkpoints (single process).
    _rank0_then_sync(pruning_ckpts.launch_prune_ckpt)

    # Build the subblock library and its statistics (single process).
    _rank0_then_sync(build_library_and_stats.launch_build_library_and_stats)

    # Calculate one-block scores (distributed).
    scoring.launch_scoring(cfg, runtime)

    # Run the MIP search and realize the selected models (distributed).
    mip_and_realize_models.launch_mip_and_realize_model(cfg, runtime)

    return cfg
Loading