-
Notifications
You must be signed in to change notification settings - Fork 319
Draft: Merge anymodel pruning #990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e82164f
2099df3
eb5cf8a
c9de41c
3c1bc1f
8357136
6cc2194
ee4e1e3
449b523
fb27bba
b350f82
fafe5a3
e988248
c717852
030f126
8dcdfbf
70df0df
bb56662
ecd953e
ee8f538
c9b76a1
6e3af61
47414d5
a8305d8
68421a5
d6b8028
ecd2341
f9d845d
d171b01
722da90
934ab2f
176a435
02e2c9b
0fc10a1
d9a8647
8398294
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,15 +14,22 @@ | |
| # limitations under the License. | ||
| # mypy: ignore-errors | ||
|
|
||
| """TODO Add description""" | ||
| """Initialize child models from parent models using AnyModel approach with deci_x_patcher.""" | ||
|
|
||
| import json | ||
| import time | ||
| from pathlib import Path | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| import yaml | ||
| from transformers import AutoModelForCausalLM | ||
|
|
||
| from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM | ||
| from modelopt.torch.puzzletron.anymodel.model_descriptor import ( | ||
| ModelDescriptor, | ||
| ModelDescriptorFactory, | ||
| ) | ||
| from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher | ||
| from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( | ||
| GQAInitMode, | ||
| HiddenSizeInitMode, | ||
|
|
@@ -31,85 +38,37 @@ | |
| create_child_state_dict, | ||
| update_model_config, | ||
| ) | ||
| from modelopt.torch.puzzletron.tools.checkpoint_utils import ( | ||
| copy_tokenizer, | ||
| load_model_config, | ||
| load_state_dict, | ||
| ) | ||
| from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict | ||
| from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( | ||
| _save_checkpoint, | ||
| copy_deci_lm_hf_code, | ||
| load_model_config, | ||
| ) | ||
| from modelopt.torch.puzzletron.tools.logger import mprint | ||
|
|
||
| """ | ||
|
|
||
| Usage example - remove all/some routed experts: | ||
| =============================================== | ||
|
|
||
| PARENT_DIR=".../meta-llama/Llama-4-Scout-17B-16E-Instruct--deci-hf" | ||
|
|
||
| MLP_INIT_MODE="ConcatExpertsIntoDenseFFN" | ||
|
|
||
| ## remove all routed experts, turn the shared expert into a dense FFN | ||
| # OUTPUT_DIR="/.../micro_scout/Scout-remove-routed-experts" | ||
| # MODEL_CONFIG_OVERRIDES_JSON=' | ||
| # { | ||
| # "ffn": [ | ||
| # { | ||
| # "moe": null, | ||
| # "intermediate_size": 14336, | ||
| # "gated": true, | ||
| # "hidden_act": "silu" | ||
| # } | ||
| # ] | ||
| # } | ||
| # ' | ||
|
|
||
| ## concat the shared expert with one routed expert into a dense FFN | ||
| OUTPUT_DIR=".../scratch/micro_scout/Scout-ConcatExpertsIntoDenseFFN-concat-shared-and-3-routed" | ||
| MODEL_CONFIG_OVERRIDES_JSON=' | ||
| { | ||
| "ffn": [ | ||
| { | ||
| "moe": null, | ||
| "intermediate_size": 14336, | ||
| "gated": true, | ||
| "hidden_act": "silu" | ||
| } | ||
| ] | ||
| } | ||
| ' | ||
|
|
||
| echo "" | ||
| echo "MODEL_CONFIG_OVERRIDES_JSON:" | ||
| echo "${MODEL_CONFIG_OVERRIDES_JSON}" | ||
|
|
||
| python -m modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent \ | ||
| --parent_checkpoint_dir="$PARENT_DIR" \ | ||
| --model_config_overrides_json="$MODEL_CONFIG_OVERRIDES_JSON" \ | ||
| --output_checkpoint_dir="$OUTPUT_DIR" \ | ||
| --mlp_init_mode="$MLP_INIT_MODE" \ | ||
| --mlp_init_config_yaml="$MLP_INIT_CONFIG_YAML" | ||
| """ | ||
| from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config | ||
|
|
||
|
|
||
| def init_child_from_parent( | ||
| descriptor: ModelDescriptor, | ||
| pruning_mixin, | ||
| parent_checkpoint_dir: str, | ||
| model_config_overrides_json: str, | ||
| model_config_overrides_dict: dict | str, | ||
| output_checkpoint_dir: str, | ||
| gqa_init_mode: GQAInitMode, | ||
| mlp_init_mode: MlpInitMode, | ||
| mlp_init_config_yaml: str | None, | ||
| mlp_init_config_yaml: Optional[str], | ||
| linear_init_mode: LinearInitMode, | ||
| hidden_size_init_mode: HiddenSizeInitMode | None = None, | ||
| channel_importance_path: str | None = None, | ||
| max_workers: int | None = None, # Auto-calculate optimal workers if None | ||
| max_layer_workers: int | None = None, # Auto-calculate optimal workers if None | ||
| hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, | ||
| channel_importance_path: Optional[str] = None, | ||
| max_workers: Optional[int] = None, # Auto-calculate optimal workers if None | ||
| max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None | ||
| ) -> None: | ||
| """Init child models from parent models in the style of bypass training, | ||
| """ | ||
| Init child models from parent models in the style of bypass training, | ||
| but without having to run the entire bypass pipeline. | ||
|
|
||
| Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations. | ||
|
|
||
| I/O Optimization Parameters: | ||
| - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files)) | ||
| - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) | ||
|
|
@@ -123,16 +82,16 @@ def init_child_from_parent( | |
| "We do not support random init of any subblock in this script to avoid initializing the student model" | ||
| ) | ||
|
|
||
| descriptor = ModelDescriptorFactory.get(descriptor) | ||
|
|
||
| copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir) | ||
|
|
||
| parent_model_config = load_model_config(parent_checkpoint_dir) | ||
| parent_state_dict = load_state_dict(parent_checkpoint_dir) | ||
|
|
||
| # Parse the model config overrides | ||
| if isinstance(model_config_overrides_json, str): | ||
| model_config_overrides_dict = json.loads(model_config_overrides_json) | ||
| else: | ||
| model_config_overrides_dict = model_config_overrides_json | ||
| # Parse JSON if string | ||
| if isinstance(model_config_overrides_dict, str): | ||
| model_config_overrides_dict = json.loads(model_config_overrides_dict) | ||
|
|
||
| # Separate global config overrides from block-level overrides | ||
| global_config_overrides = {} | ||
|
|
@@ -146,7 +105,7 @@ def init_child_from_parent( | |
|
|
||
| # Load child model config with global overrides | ||
| child_model_config = load_model_config( | ||
| checkpoint_dir=parent_checkpoint_dir, | ||
| parent_checkpoint_dir, | ||
| model_config_overrides=global_config_overrides, | ||
| ignore_unexpected_config_keys=True, | ||
| ) | ||
|
|
@@ -159,19 +118,32 @@ def init_child_from_parent( | |
| ) | ||
|
|
||
| with torch.device("meta"): | ||
| child_model = DeciLMForCausalLM(child_model_config) | ||
| # Pass block_configs explicitly so patcher works for VL models where | ||
| # decoder layers receive nested config (e.g., text_config) without block_configs | ||
| with deci_x_patcher( | ||
| model_descriptor=descriptor, block_configs=child_model_config.block_configs | ||
| ): | ||
| model_class = _get_model_class_from_config(child_model_config) | ||
| # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() | ||
| if model_class is AutoModelForCausalLM: | ||
| child_model = model_class.from_config(child_model_config, trust_remote_code=True) | ||
| else: | ||
| child_model = model_class._from_config(child_model_config) | ||
|
Comment on lines
+123
to
+131
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # Find the file and check if it exists
find . -type f -name "init_child_from_parent.py" | head -20Repository: NVIDIA/Model-Optimizer Length of output: 143 🏁 Script executed: # Once found, read the relevant section of the file
cd modelopt/torch/puzzletron/tools/bypassed_training 2>/dev/null && wc -l init_child_from_parent.pyRepository: NVIDIA/Model-Optimizer Length of output: 95 🏁 Script executed: # Read the full function to understand the context and function signature
cat -n modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py | head -150Repository: NVIDIA/Model-Optimizer Length of output: 6974 🏁 Script executed: # Check for inline comments justifying the hardcoded True and look at more context
cat -n modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py | sed -n '120,140p'Repository: NVIDIA/Model-Optimizer Length of output: 1179 🏁 Script executed: # Search for all calls to init_child_from_parent function
rg "init_child_from_parent\(" --type py -B 2 -A 2Repository: NVIDIA/Model-Optimizer Length of output: 2855 🏁 Script executed: # Check if there are other uses of trust_remote_code in the codebase to understand patterns
rg "trust_remote_code" --type pyRepository: NVIDIA/Model-Optimizer Length of output: 22982 🏁 Script executed: # Check the sharded_checkpoint_utils.py for similar issues
cat -n modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py | grep -B 5 -A 5 "trust_remote_code"Repository: NVIDIA/Model-Optimizer Length of output: 756 🏁 Script executed: # Let's see the full context of the call in pruning_ckpts.py to understand the impact
rg "init_child_from_parent\(" modelopt/torch/puzzletron/pruning/pruning_ckpts.py -B 2 -A 15 | head -100Repository: NVIDIA/Model-Optimizer Length of output: 4850 🏁 Script executed: # Verify that there are no inline comments justifying the hardcoded True
sed -n '126,132p' modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py
Repository: NVIDIA/Model-Optimizer. Length of output: 481.
Don't hardcode `trust_remote_code=True`.
Line 129 turns child-model construction into an RCE boundary for any checkpoint/config that carries custom modeling code. Thread this through as a caller-controlled flag with a safe default of `False`.
Proposed fix:
def init_child_from_parent(
descriptor: ModelDescriptor,
pruning_mixin,
parent_checkpoint_dir: str,
model_config_overrides_dict: dict | str,
output_checkpoint_dir: str,
gqa_init_mode: GQAInitMode,
mlp_init_mode: MlpInitMode,
mlp_init_config_yaml: Optional[str],
linear_init_mode: LinearInitMode,
hidden_size_init_mode: Optional[HiddenSizeInitMode] = None,
channel_importance_path: Optional[str] = None,
max_workers: Optional[int] = None, # Auto-calculate optimal workers if None
max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None
+ trust_remote_code: bool = False,
) -> None:
@@
if model_class is AutoModelForCausalLM:
- child_model = model_class.from_config(child_model_config, trust_remote_code=True)
+ child_model = model_class.from_config(
+ child_model_config, trust_remote_code=trust_remote_code
+ )
else:
child_model = model_class._from_config(child_model_config)Per coding guidelines: "Do not hardcode 🤖 Prompt for AI Agents |
||
|
|
||
| child_state_dict_with_meta_tensors = child_model.state_dict() | ||
|
|
||
| mlp_init_config = ( | ||
| yaml.safe_load(mlp_init_config_yaml) | ||
| if isinstance(mlp_init_config_yaml, str) is None | ||
| if isinstance(mlp_init_config_yaml, str) | ||
| else mlp_init_config_yaml | ||
| ) | ||
|
|
||
| # Profile create_child_state_dict with automatic layer parallelization | ||
| mprint("Starting create_child_state_dict...") | ||
| start_time = time.time() | ||
| child_state_dict = create_child_state_dict( | ||
| pruning_mixin=pruning_mixin, | ||
| descriptor=descriptor, | ||
| original_state_dict=parent_state_dict, | ||
| new_state_dict=child_state_dict_with_meta_tensors, | ||
| original_config=parent_model_config, | ||
|
|
@@ -182,7 +154,7 @@ def init_child_from_parent( | |
| linear_init_mode=linear_init_mode, | ||
| hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs, | ||
| channel_importance_path=channel_importance_path, | ||
| max_layer_workers=max_layer_workers, # Will auto-calculate if None | ||
| max_layer_workers=max_layer_workers, | ||
| ) | ||
| create_child_state_dict_time = time.time() - start_time | ||
| mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds") | ||
|
|
@@ -196,7 +168,8 @@ def init_child_from_parent( | |
| child_model_config, | ||
| child_state_dict, | ||
| output_checkpoint_dir, | ||
| max_workers=max_workers, # Will auto-calculate if None | ||
| descriptor, | ||
| max_workers=max_workers, | ||
| ) | ||
| save_checkpoint_time = time.time() - start_time | ||
| mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") | ||
|
|
@@ -207,7 +180,7 @@ def init_child_from_parent( | |
| total_core_time = create_child_state_dict_time + save_checkpoint_time | ||
| actual_layer_workers = max_layer_workers if max_layer_workers else "auto" | ||
| actual_io_workers = max_workers if max_workers else "auto" | ||
| mprint("\n=== PROFILING SUMMARY ===") | ||
| mprint(f"\n=== PROFILING SUMMARY ===") | ||
| mprint( | ||
| f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)" | ||
| ) | ||
|
|
@@ -216,4 +189,4 @@ def init_child_from_parent( | |
| ) | ||
| mprint(f"Total core processing: {total_core_time:.2f}s") | ||
| mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") | ||
| mprint("=========================\n") | ||
| mprint(f"=========================\n") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reject non-object JSON overrides early.
If the string decodes to anything other than a mapping, Line 100 fails later with a generic
"... has no attribute 'items'". Validate the parsed type here and raise a clearer error at the boundary.
Proposed fix:
  if isinstance(model_config_overrides_dict, str):
      model_config_overrides_dict = json.loads(model_config_overrides_dict)
+ if not isinstance(model_config_overrides_dict, dict):
+     raise TypeError(
+         "model_config_overrides_dict must be a dict or a JSON object string"
+     )
🤖 Prompt for AI Agents