Add build replacement library to the compress algorithm. #616
Merged
58 commits
694c317 Add decilm modelling code (danielkorzekwa)
991659f Add decilm modelling code. (danielkorzekwa)
8489cee Add transformers codebase (danielkorzekwa)
f0afefe Add transformers code (danielkorzekwa)
b3ed5bc Add decilm modelling code (danielkorzekwa)
a700da5 Add decilm modelling code (danielkorzekwa)
b59b679 Correct licence headers (danielkorzekwa)
1abdf3e Correct licence headers (danielkorzekwa)
66609b1 Add decilm code (danielkorzekwa)
7da0a8a Add decilm code (danielkorzekwa)
6e09a81 Add decilm code (danielkorzekwa)
2e3f5da Add decilm code (danielkorzekwa)
418890e Add decilm code (danielkorzekwa)
01f4fc1 Make llama3 converter self-contained (no deps on internal Nvidia code) (danielkorzekwa)
c57eed4 Add common module (danielkorzekwa)
3dc37b3 module refactoring (danielkorzekwa)
10ffdfe refactoring (danielkorzekwa)
27a4456 add shared_checkpointing_utils (danielkorzekwa)
b0e22b7 Add json tools (danielkorzekwa)
52e7827 add logger (danielkorzekwa)
f5c1c87 import refactoring (danielkorzekwa)
0aa6320 add post_init_sparse module (danielkorzekwa)
35d0dbc Add post_init_sparse (danielkorzekwa)
e39a1ad merginy hydra.py and hydra_utils.py (danielkorzekwa)
1bd0c67 Add integrationt test for attention pruning (danielkorzekwa)
0ecd52b add score_pruning_activations (danielkorzekwa)
278c6b7 import refactoring (danielkorzekwa)
7a0af16 add dist_utils (danielkorzekwa)
0f0cbbd Add validate_model (danielkorzekwa)
cb5cf25 Add activation scoring hooks for pruning (danielkorzekwa)
6f82a67 make validate_model self-contained (danielkorzekwa)
a87fb79 updage validatete_pipeline to use DeciLMForCausalLM from modelopt (danielkorzekwa)
b227521 fix imports (danielkorzekwa)
ca7ab3f add sewing_kit (danielkorzekwa)
a7a4adc add sewing_kit (danielkorzekwa)
ad84c26 fix imports (danielkorzekwa)
3d7e8a2 fix imports (danielkorzekwa)
3d755b2 add pruning_ckpts (danielkorzekwa)
845d453 add pruning_ckpts (danielkorzekwa)
4fd921b import refactoring (danielkorzekwa)
3641847 refactor imports (danielkorzekwa)
8d6333b import refactoring (danielkorzekwa)
dcb86e2 Add build_replacement_library (danielkorzekwa)
dfd3adc import refactoring (danielkorzekwa)
daf94d3 add replacement_library (danielkorzekwa)
ab6e9e3 refactor imports (danielkorzekwa)
01a6aee refactor imports (danielkorzekwa)
d6a0fb1 Merge branch 'feature/compress' into dkorzekwa/build_library (danielkorzekwa)
3a6b857 Delete not needed mistral tokenizer. (danielkorzekwa)
dcf7e9e Add doc strings and remove not used imports (danielkorzekwa)
1bfbe14 Delete empty module (danielkorzekwa)
6bacaf7 Add doc string. (danielkorzekwa)
925365b Add doc string (danielkorzekwa)
1487876 Improve doc string. (danielkorzekwa)
5cf9b06 improve doc string (danielkorzekwa)
108c187 Add pandas to 'compress' dependencies (danielkorzekwa)
1d0b1f7 Replace frozendict with immutabledict (danielkorzekwa)
4d38b53 Add immutabledict to compress dependencies. (danielkorzekwa)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| #!/usr/bin/env python3 | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| | ||
| """ | ||
| Unified command that runs build_replacement_library followed by calc_subblock_stats. | ||
| | ||
| This script combines the functionality of both commands into a single workflow: | ||
| 1. First, it builds the replacement library for the puzzle | ||
| 2. Then, it calculates subblock statistics | ||
| | ||
| Usage: | ||
| | ||
| python modelopt.torch._compress.build_library_and_stats.py --config-dir configs --config-name Llama-3_1-8B puzzle_dir=/path/to/puzzle/dir dataset_path=/path/to/dataset | ||
| | ||
| The script uses the same Hydra configuration as the individual commands and supports | ||
| all the same configuration parameters for both build_replacement_library and calc_subblock_stats. | ||
| """ | ||
| | ||
| import hydra | ||
| from calc_subblock_stats import launch_calc_subblock_stats | ||
| from omegaconf import DictConfig | ||
| | ||
| from modelopt.torch._compress.replacement_library.build_replacement_library import ( | ||
|     launch_build_replacement_library, | ||
| ) | ||
| from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers | ||
| from modelopt.torch._compress.tools.logger import mprint | ||
| from modelopt.torch._compress.utils.parsing import format_global_config | ||
| | ||
| | ||
| def launch_build_library_and_stats(cfg: DictConfig) -> None: | ||
|     """ | ||
|     Launch both build_replacement_library and calc_subblock_stats in sequence. | ||
| | ||
|     Args: | ||
|         cfg: Hydra configuration containing settings for both commands | ||
|     """ | ||
|     mprint("=" * 80) | ||
|     mprint("STARTING UNIFIED BUILD LIBRARY AND STATS WORKFLOW") | ||
|     mprint("=" * 80) | ||
| | ||
|     # Step 1: Build replacement library | ||
|     mprint("=" * 50) | ||
|     mprint("STEP 1: Building Replacement Library") | ||
|     mprint("=" * 50) | ||
| | ||
|     try: | ||
|         launch_build_replacement_library(cfg) | ||
|         mprint("✅ Replacement library built successfully!") | ||
|     except Exception as e: | ||
|         mprint(f"❌ Failed to build replacement library: {e}") | ||
|         raise | ||
| | ||
|     # Step 2: Calculate subblock statistics | ||
|     mprint("=" * 50) | ||
|     mprint("STEP 2: Calculating Subblock Statistics") | ||
|     mprint("=" * 50) | ||
| | ||
|     try: | ||
|         launch_calc_subblock_stats(cfg) | ||
|         mprint("✅ Subblock statistics calculated successfully!") | ||
|     except Exception as e: | ||
|         mprint(f"❌ Failed to calculate subblock statistics: {e}") | ||
|         raise | ||
| | ||
|     mprint("=" * 80) | ||
|     mprint("UNIFIED WORKFLOW COMPLETED SUCCESSFULLY! 🎉") | ||
|     mprint("=" * 80) | ||
| | ||
|     mprint("Generated files:") | ||
|     mprint(f" - {cfg.puzzle_dir}/block_library.json") | ||
|     mprint(f" - {cfg.puzzle_dir}/subblock_library.json") | ||
|     mprint(f" - {cfg.puzzle_dir}/replacement_library.json") | ||
|     mprint(f" - {cfg.puzzle_dir}/single_sequence_replacement_solutions.json") | ||
|     mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}") | ||
|     if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"): | ||
|         mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}") | ||
| | ||
| | ||
| @hydra.main("", version_base="1.3") | ||
| def main(cfg: DictConfig) -> None: | ||
|     """ | ||
|     Main entry point for the unified build library and stats command. | ||
| | ||
|     This function uses Hydra for configuration management and runs both | ||
|     build_replacement_library and calc_subblock_stats in sequence. | ||
|     """ | ||
|     cfg = hydra.utils.instantiate(cfg) | ||
|     mprint("Unified Build Library and Stats Configuration:") | ||
|     mprint(format_global_config(cfg)) | ||
|     launch_build_library_and_stats(cfg) | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
|     register_hydra_resolvers() | ||
|     main() | ||
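
The file above runs two steps in order, logs the outcome of each, and re-raises on failure so the second step never runs after a failed first step. A minimal, dependency-free sketch of that pattern might look like the following. The `run_steps` helper and its `log` parameter are hypothetical names introduced only for illustration; they are not part of modelopt, and the lambdas merely stand in for `launch_build_replacement_library(cfg)` and `launch_calc_subblock_stats(cfg)`.

```python
from typing import Callable, Sequence, Tuple


def run_steps(
    steps: Sequence[Tuple[str, Callable[[], None]]],
    log: Callable[[str], None] = print,
) -> None:
    """Run named steps in order, logging each outcome and stopping at the first failure.

    Hypothetical helper for illustration; not part of modelopt.
    """
    for index, (name, step) in enumerate(steps, start=1):
        log(f"STEP {index}: {name}")
        try:
            step()
        except Exception as exc:
            log(f"Failed: {name}: {exc}")
            # Re-raise so that later steps never run after a failed step.
            raise
        log(f"Done: {name}")


if __name__ == "__main__":
    # Stand-ins for the two real launch functions used in the PR.
    run_steps(
        [
            ("Building Replacement Library", lambda: None),
            ("Calculating Subblock Statistics", lambda: None),
        ]
    )
```

Passing the logger in as a parameter keeps the sketch testable without Hydra or a distributed logger; the file in this PR calls `mprint` directly instead.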
Why did this need the import?
To access build_library_and_stats.