def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool:
    """Report whether the configured activation hook method supports checkpointing.

    Args:
        activation_hooks_kwargs: Hook configuration. Only its ``"method"`` entry is
            read. ``None`` or an empty mapping means "no method configured".

    Returns:
        bool: True if the method is one of the methods listed below as having
        save_state/load_state implemented, False otherwise.
    """
    # A missing configuration cannot name a supported method; this also avoids
    # calling ``.get`` on ``None``.
    if not activation_hooks_kwargs:
        return False

    # Methods with implemented checkpoint support
    supported_methods = {
        "iterative",  # IterativeChannelContributionHook: save_state/load_state implemented
        "independent",  # IndependentChannelContributionHook: save_state/load_state implemented
        "stats",  # RouterStatsHook: save_state/load_state implemented
        "ranked_choice_voting",  # RankedChoiceVotingHook: save_state/load_state implemented
    }

    return activation_hooks_kwargs.get("method", "") in supported_methods


def check_scoring_completion(
    activations_log_dir: str, runtime, activation_hooks_kwargs=None
) -> bool:
    """Check whether scoring already finished by looking for its output files.

    Scoring counts as complete when ``activations_log_dir`` contains at least one
    ``rank_*.pth`` file and an ``args.json`` file.

    Args:
        activations_log_dir: Directory where activation logs should be stored.
        runtime: Runtime object for distributed processing, or None. When a runtime
            is given, only its main process inspects the filesystem.
        activation_hooks_kwargs: Hook configuration. Accepted for interface
            compatibility; this function does not read it.

    Returns:
        bool: True if both kinds of output file were found. Always False on a
        non-main process, so distributed callers must take the answer from the
        main process.
    """
    # Only check completion on main process (or if no distributed runtime)
    if runtime is not None and not runtime.is_main_process:
        return False

    log_dir = Path(activations_log_dir)
    if not log_dir.exists():
        return False

    # Check for rank files (at least rank_0.pth should exist)
    rank_files = list(log_dir.glob("rank_*.pth"))
    if not rank_files:
        return False

    # Check for args.json (created by main process)
    has_args_json = (log_dir / "args.json").exists()
    if not has_args_json:
        return False

    # Completion info for debugging
    mprint(f"Found completed scoring in {activations_log_dir}")
    mprint(f" - Found {len(rank_files)} rank files")
    mprint(f" - Found args.json: {has_args_json}")
    return True


def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool:
    """Decide whether scoring can be skipped entirely (only if 100% complete).

    Partial progress is not skipped: it proceeds to validate_model, which resumes.

    Args:
        cfg: Configuration object; ``cfg.pruning`` is read for
            ``activations_log_dir``, ``force_restart_scoring`` and
            ``activation_hooks_kwargs``.
        runtime: Runtime object for distributed processing, or None.

    Returns:
        bool: True if scoring should be skipped (100% completed), False if it
        should run or resume.
    """
    # Without a log directory there is nothing to inspect, so scoring must run.
    activations_log_dir = getattr(cfg.pruning, "activations_log_dir", None)
    if activations_log_dir is None:
        mprint("No activations_log_dir specified, running scoring")
        return False

    # Check for force restart flag
    if getattr(cfg.pruning, "force_restart_scoring", False):
        mprint("Force restart flag set, will restart scoring regardless of existing artifacts")
        return False

    activation_hooks_kwargs = getattr(cfg.pruning, "activation_hooks_kwargs", {})

    # Check if scoring is already completed
    is_completed = check_scoring_completion(activations_log_dir, runtime, activation_hooks_kwargs)

    # Only the main process looked at the filesystem, so share its answer with
    # every other process in distributed mode.
    if runtime is not None and runtime.world_size > 1:
        should_skip = [is_completed]  # Use list for mutable object
        torch.distributed.broadcast_object_list(should_skip, src=0)
        is_completed = should_skip[0]

    if is_completed:
        mprint("Scoring 100% completed, skipping...")

    return is_completed


# Old progress tracking removed - checkpoint manager handles all progress tracking


def launch_score_activations(cfg: DictConfig, runtime):
    """Run pruning activation scoring unless it has already completed.

    Args:
        cfg: Configuration object; ``cfg.pruning`` is passed to validate_model.
        runtime: Runtime object passed through to the completion check and to
            validate_model.
    """
    # Check if we should skip scoring entirely (only if 100% complete)
    if should_skip_scoring_completely(cfg, runtime):
        return

    mprint("Starting pruning activation scoring...")

    # The checkpoint manager inside validate_model handles all progress tracking
    validate_model(args=cfg.pruning, runtime=runtime)


@hydra.main("", version_base="1.3")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: build the runtime and launch activation scoring.

    Uses NativeDdpRuntime when ``is_distributed()`` reports a distributed launch
    and BaseRuntime otherwise, both with bfloat16 as dtype.

    Args:
        cfg: Hydra configuration; instantiated before use.
    """
    cfg = hydra.utils.instantiate(cfg)
    mprint(format_global_config(cfg, title="Score Pruning Activations"))

    _runtime = (
        NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=cfg.nccl_timeout_minutes)
        if is_distributed()
        else BaseRuntime(dtype=torch.bfloat16)
    )
    with _runtime as runtime:
        launch_score_activations(cfg, runtime)
        runtime.wait_for_everyone()


if __name__ == "__main__":
    register_hydra_resolvers()
    main()
transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override -from utils.utils import EmptyInitOnDevice from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( DeciLMDecoderLayer, - DeciLMForCausalLM, rope_type_to_class, ) from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.runtime import IRuntime +from modelopt.torch._compress.utils.utils import EmptyInitOnDevice class DummyModule(nn.Module): @@ -392,7 +392,7 @@ def load_sharded_state_dict( partial_state_dict.update(shard) else: with safe_open(safetensors_path, framework="pt", device=str(device)) as f: - for key in f: + for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable if key in keys_to_load: partial_state_dict[key] = f.get_tensor(key) return partial_state_dict @@ -417,6 +417,6 @@ def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: state_dict_shapes = {} for safetensors_path in shard_paths: with safe_open(safetensors_path, framework="pt") as f: - for key in f: + for key in f.keys(): # noqa: SIM118 - safe_open objects require .keys(), not directly iterable state_dict_shapes[key] = tuple(f.get_tensor(key).shape) return state_dict_shapes diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py new file mode 100644 index 0000000000..e264ea6813 --- /dev/null +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import textwrap +from pathlib import Path + +import torch.distributed +from omegaconf import DictConfig +from torch import nn +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from utils.activation_hooks.utils import register_activation_hooks +from utils.data.dataloaders import create_validation_dataloader +from utils.parsing import simple_parse_args_string +from utils.validate_runtime_pipeline import HiddenStatesAndLMHead, calculate_losses_pipeline +from utils.validation import calculate_losses + +from modelopt.torch._compress.tools.checkpoint_utils_hf import load_checkpoint +from modelopt.torch._compress.tools.logger import aprint, mprint +from modelopt.torch._compress.tools.runtime import IRuntime, NativeDdpRuntime +from modelopt.torch._compress.tools.sharded_checkpoint_utils import load_and_shard_model + +# #TODO:Import slack from root utils directory +# root_path = os.path.join(os.path.dirname(__file__), "..", "..") +# if root_path not in sys.path: +# sys.path.append(root_path) +# from utils.slack import send_slack_message + +""" +Two goals: +1) Calculate lm loss and token accuracy for a model. +May raise lots of NCCL warnings when it finishes, don't be alarmed. +Can be used to validate a HuggingFace model. 
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for model validation and activation scoring.

    Returns:
        argparse.ArgumentParser: Parser with all validation options registered.
        ``--dataset_path`` is the only required option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="Required unless a model is passed to the function",
    )
    parser.add_argument("--dataset_path", type=str, required=True)

    parser.add_argument("--output_dir_name", type=str, default="validation")
    parser.add_argument(
        "--calculate_full_score_ablations",
        action="store_true",
        help="Calculates a diverse suite of teacher similarity scores. "
        "By default only a small suite is calculated, which is good for most use-cases.",
    )

    parser.add_argument("--tokenizer_name", type=str, default=None)
    parser.add_argument("--data_column", type=str, default="content")
    # TODO: Add help text for FIM rate, also for others less obvious args
    parser.add_argument("--fim_rate", type=float, default=0)
    parser.add_argument("--fim_spm_rate", type=float, default=0)
    parser.add_argument("--eval_samples", type=int, default=None)
    parser.add_argument("--block_size", type=int, default=4096)
    parser.add_argument("--micro_batch_size", type=int, default=4)
    parser.add_argument("--val_dataset_name", type=str, default="__auto__")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--source_datasets_to_discard", nargs="+", type=str)
    parser.add_argument("--bos_rate", type=float, default=1.0)
    parser.add_argument("--shuffle_seed", type=int, default=None)
    parser.add_argument("--varlen", action="store_true")
    parser.add_argument("--pipeline_parallel", action="store_true")
    parser.add_argument("--write_results", action="store_true")
    parser.add_argument("--activations_log_dir", type=str, default=None)
    parser.add_argument(
        "--activation_hooks_kwargs",
        type=str,
        default=None,
        help="Comma separated string arguments, e.g. `arg1=val1,arg2=val2`",
    )
    parser.add_argument(
        "--calc_losses_on_cpu",
        action="store_true",
        help="Very slow, not recommended. Can help avoid OOM.",
    )
    return parser


def parse_args() -> argparse.Namespace:
    """Parse the known command-line arguments, ignoring unrecognised ones.

    Returns:
        argparse.Namespace: Values for the options defined by build_arg_parser.
    """
    args, _ = build_arg_parser().parse_known_args()
    return args


@torch.no_grad()
def validate_model(
    args: argparse.Namespace | DictConfig,
    model: PreTrainedModel | None = None,
    tokenizer: PreTrainedTokenizerBase | None = None,
    target_hidden_states_per_batch: list[torch.Tensor] | None = None,
    return_hidden_states: bool = False,
    runtime: IRuntime | None = None,
    calculate_full_score_ablations: bool = False,
    val_dataloader: DataLoader | None = None,
) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]:
    """Calculate validation losses and, optionally, collect activation scores.

    The dataloader and the model are prepared here when not supplied. When
    ``args.activations_log_dir`` is set, activation hooks and a scoring checkpoint
    manager are created, the model's ``lm_head`` is replaced by an identity
    module, and the activation logs are dumped after the forward passes.

    Args:
        args: Validation options. Attributes read directly here are
            ``eval_samples``, ``micro_batch_size``, ``activations_log_dir``,
            ``activation_hooks_kwargs``, ``calc_losses_on_cpu``,
            ``model_name_or_path`` and ``write_results``; prepare_model and
            prepare_dataloader read further attributes.
        model: Model to validate. When None it is loaded by prepare_model.
        tokenizer: Passed to prepare_dataloader when the dataloader is built here.
        target_hidden_states_per_batch: Forwarded to calculate_losses_pipeline.
        return_hidden_states: Forwarded to calculate_losses_pipeline.
        runtime: When None, calculate_losses is used; otherwise
            calculate_losses_pipeline is used, and the dataloader is only built
            on the runtime's main process.
        calculate_full_score_ablations: Forwarded to calculate_losses_pipeline.
        val_dataloader: Pre-built dataloader; when None one is built here.

    Returns:
        The ``(losses, hidden_states_per_batch)`` pair returned by the loss
        helper. ``losses`` may be None, in which case no summary is printed.

    Raises:
        ValueError: If ``args.activations_log_dir`` is set but
            ``args.eval_samples`` is None.
    """
    if val_dataloader is None:
        val_dataloader = (
            prepare_dataloader(args, tokenizer)
            if (runtime is None or runtime.is_main_process)
            else None
        )

    model = prepare_model(args, model, runtime)

    just_model_forward = False
    checkpoint_manager = None
    activation_hooks = None

    if args.activations_log_dir is not None:
        # The iteration count is only needed by the activation hooks, so it is
        # computed here. ``eval_samples`` defaults to None, and plain validation
        # must not fail on the integer division.
        if args.eval_samples is None:
            raise ValueError(
                "eval_samples must be set when activations_log_dir is given: "
                "it is needed to compute validation_full_iters for the activation hooks"
            )
        # model pipeline, single data rank
        validation_full_iters = args.eval_samples // args.micro_batch_size

        activation_hooks_kwargs = (
            simple_parse_args_string(args.activation_hooks_kwargs)
            if isinstance(args.activation_hooks_kwargs, str)
            else args.activation_hooks_kwargs
        )
        if activation_hooks_kwargs is None:
            # No hook options were given; start from an empty configuration
            # instead of failing on item assignment to None.
            activation_hooks_kwargs = {}
        activation_hooks_kwargs["validation_full_iters"] = validation_full_iters

        # Create activation hooks first
        activation_hooks, hook_class = register_activation_hooks(
            model=model, activation_hooks_kwargs=activation_hooks_kwargs
        )

        # Create checkpoint manager with hooks
        from utils.checkpoint_manager import ScoringCheckpointManager

        mprint(
            f"Creating checkpoint manager with {len(activation_hooks)} hooks for dir: {args.activations_log_dir}"
        )
        checkpoint_manager = ScoringCheckpointManager(
            checkpoint_dir=args.activations_log_dir,
            runtime=runtime,
            activation_hooks=activation_hooks,
            checkpoint_interval=50,  # Save every 50 batches
        )

        # Load existing checkpoint if available
        mprint("Attempting to load existing checkpoint...")
        checkpoint_data = checkpoint_manager.load_checkpoint()
        if checkpoint_data:
            mprint(f"Checkpoint loaded successfully: {checkpoint_data}")
        else:
            mprint("No checkpoint found, starting fresh")
        # NOTE(review): the flattened source hides the nesting of the next two
        # statements; they are placed at the activations-branch level here.
        # Confirm they are not meant to run only when no checkpoint was found.
        just_model_forward = True
        model.lm_head = nn.Identity()

    if runtime is None:
        losses, hidden_states_per_batch = calculate_losses(
            model=model,
            dataloader=val_dataloader,
            checkpoint_manager=checkpoint_manager,
        )
    else:
        losses, hidden_states_per_batch = calculate_losses_pipeline(
            runtime=runtime,
            stitched_model=model,
            dataloader=val_dataloader,
            target_hidden_states_per_batch=target_hidden_states_per_batch,
            return_hidden_states=return_hidden_states,
            calculate_full_score_ablations=calculate_full_score_ablations,
            calc_on_cpu=args.calc_losses_on_cpu,
            just_model_forward=just_model_forward,
            checkpoint_manager=checkpoint_manager,
        )

    if losses is not None:
        avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()}

        results_str = f"""
        validate_model:
        {args.model_name_or_path=}
        Average losses = {avg_losses}
        Actual num samples = {len(next(iter(losses.values()))["per_sample"])}
        {args=}
        """
        results_str = textwrap.dedent(results_str)
        aprint(results_str)
        if args.write_results:
            Path(f"{args.model_name_or_path}/validate_model_results.txt").write_text(results_str)
        # TODO: send_slack_message(results_str)

    if args.activations_log_dir is not None:
        hook_class.dump_activations_logs(activation_hooks, args.activations_log_dir, args, runtime)

    return losses, hidden_states_per_batch


def prepare_model(
    args: argparse.Namespace,
    model: PreTrainedModel | None = None,
    runtime: IRuntime | None = None,
) -> nn.Module:
    """Return the model to validate in eval mode, loading it when none is given.

    Args:
        args: Options; ``model_name_or_path`` and ``block_size`` are read when the
            model has to be loaded.
        model: Already constructed model; returned as is after ``eval()``.
        runtime: When given, the model is loaded with load_and_shard_model.
            Otherwise load_checkpoint is tried first and
            ``AutoModelForCausalLM.from_pretrained`` is the fallback when it
            raises FileNotFoundError.

    Returns:
        nn.Module: The model, switched to eval mode.
    """
    if model is None:
        assert args.model_name_or_path is not None
        if runtime is not None:
            model = load_and_shard_model(
                runtime,
                args.model_name_or_path,
                model_config_overrides={"block_size": args.block_size},
            )
        else:
            try:
                model = load_checkpoint(
                    args.model_name_or_path,
                    model_config_overrides={"block_size": args.block_size},
                    ignore_unexpected_config_keys=True,
                )
                model.to("cuda")
            except FileNotFoundError:
                model = AutoModelForCausalLM.from_pretrained(
                    args.model_name_or_path,
                    torch_dtype="auto",
                    device_map="auto",
                    trust_remote_code=True,
                )

    model.eval()
    return model


def prepare_dataloader(
    args: argparse.Namespace,
    tokenizer: PreTrainedTokenizerBase | None = None,
) -> DataLoader:
    """Build the validation dataloader from the given options.

    Args:
        args: Options forwarded to create_validation_dataloader (seed, block size,
            dataset path, FIM rates, batch size, and so on).
        tokenizer: Tokenizer to use. When None, one is loaded from
            ``args.tokenizer_name`` if set, otherwise from
            ``args.model_name_or_path``.

    Returns:
        DataLoader: The dataloader created by create_validation_dataloader.
    """
    if tokenizer is None:
        tokenizer_name = getattr(args, "tokenizer_name", None)
        assert (tokenizer_name is not None) or (args.model_name_or_path is not None)
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name or args.model_name_or_path, trust_remote_code=True
        )

    # NOTE(review): build_arg_parser defines no --load_dataset_fn option, so a
    # Namespace produced by parse_args has no ``load_dataset_fn`` attribute and
    # the lookup below raises AttributeError. Confirm the intended default.
    val_dataloader = create_validation_dataloader(
        accelerator=None,
        seed=args.seed,
        tokenizer=tokenizer,
        block_size=args.block_size,
        dataset=args.dataset_path,
        content_field=args.data_column,
        fim_rate=args.fim_rate,
        fim_spm_rate=args.fim_spm_rate,
        micro_batch_size=args.micro_batch_size,
        eval_samples=args.eval_samples,
        dataset_name=args.val_dataset_name,
        source_datasets_to_discard=args.source_datasets_to_discard,
        bos_rate=args.bos_rate,
        varlen=args.varlen,
        shuffle_seed=args.shuffle_seed,
        load_dataset_fn=args.load_dataset_fn,
    )

    return val_dataloader


def main():
    """Command-line entry point: parse arguments and run validate_model.

    With ``--pipeline_parallel`` the validation runs inside a NativeDdpRuntime
    (bfloat16); otherwise it runs without a runtime.
    """
    args = parse_args()
    if args.pipeline_parallel:
        with NativeDdpRuntime(dtype=torch.bfloat16) as runtime:
            validate_model(args=args, runtime=runtime)
    else:
        validate_model(args=args, runtime=None)


if __name__ == "__main__":
    main()
def is_distributed():
    """Tell whether the process environment describes a distributed launch.

    Adapted from torchtune.utils.is_distributed():
    https://docs.pytorch.org/torchtune/0.2/generated/torchtune.utils.is_distributed.html

    Returns:
        bool: True only when MASTER_PORT and MASTER_ADDR are non-empty,
        WORLD_SIZE is greater than 1, RANK is non-negative and
        ``torch.distributed`` is available.
    """
    environ = os.environ

    # Read every setting up front so that a malformed WORLD_SIZE or RANK is
    # reported regardless of the other values.
    master_port = environ.get("MASTER_PORT", "")
    master_addr = environ.get("MASTER_ADDR", "")
    world_size = int(environ.get("WORLD_SIZE", 1))
    global_rank = int(environ.get("RANK", -1))
    backend_available = dist.is_available()

    has_rendezvous = bool(master_port) and bool(master_addr)
    is_multi_process = world_size > 1 and global_rank >= 0
    return has_rendezvous and is_multi_process and bool(backend_available)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): + def __init__(self, device=None, dtype=None): + """ + Create tensors with given device and dtype and don't run initialization + (but instead use "empty tensors", i.e. uninitialized memory). + + device: `torch.device` to work with + dtype: `torch.dtype` to work with + + + Example:: + with EmptyInitOnDevice("cuda", dtype=torch.bfloat16): + model = LLaMA(model_config) + model.load_state_dict(torch.load("llama-lit/7B/lit-llama.pth"))""" + + self.device = device + self.dtype = dtype + + def __enter__(self): + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + return super().__exit__(exc_type, exc_val, exc_tb) + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if getattr(func, "__module__", None) == "torch.nn.init": + if "tensor" in kwargs: + return kwargs["tensor"] + else: + return args[0] + if ( + self.device is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("device") is None + ): + kwargs["device"] = self.device + if ( + self.dtype is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("dtype") is None + ): + kwargs["dtype"] = self.dtype + return func(*args, **kwargs)