diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py index df953bb90..7d955c5ca 100644 --- a/modelopt/torch/_compress/compress.py +++ b/modelopt/torch/_compress/compress.py @@ -28,7 +28,7 @@ from omegaconf import DictConfig from puzzle_tools.runtime import IRuntime -from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir def compress( diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 1cbfa5f30..13d418b69 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -15,6 +15,9 @@ """ Compress NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). + +It is used by mtn.convert() to convert a model from HF format to DeciLM format + do pruning scoring +and save pruned checkpoints, and by mtn.search() to perform the MIP-based NAS search. """ import datetime @@ -31,7 +34,7 @@ from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, ) -from modelopt.torch._compress.tools.hydra import initialize_hydra_config_for_dir +from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir from modelopt.torch._compress.tools.logger import mprint from modelopt.torch._compress.tools.runtime import NativeDdpRuntime from modelopt.torch.nas.conversion import NASModeRegistry diff --git a/modelopt/torch/_compress/tools/checkpoint_utils.py b/modelopt/torch/_compress/tools/checkpoint_utils.py index 4a05f82bb..43d3c4364 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils.py @@ -14,6 +14,11 @@ # limitations under the License. # mypy: ignore-errors +""" +It provides general utilities for loading and initializing PyTorch model checkpoints, +particularly for DeciLM models. +""" + import concurrent.futures import warnings from functools import partial @@ -51,7 +56,7 @@ def load_state_dict(checkpoint_dir: Path | str) -> dict[str, torch.Tensor]: if (checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME).exists() or ( checkpoint_dir / SAFE_WEIGHTS_NAME ).exists(): - from utils.sharded_checkpoint_utils import ( + from modelopt.torch._compress.tools.sharded_checkpoint_utils import ( load_sharded_state_dict, # local import to avoid circular import ) diff --git a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py index c686c1027..3c73498d5 100644 --- a/modelopt/torch/_compress/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/_compress/tools/checkpoint_utils_hf.py @@ -14,6 +14,11 @@ # limitations under the License. # mypy: ignore-errors +""" +Provides utilities for loading and saving PyTorch model checkpoints in the Hugging Face format, +particularly for DeciLM models. +""" + import concurrent.futures import fcntl import os @@ -26,15 +31,16 @@ from typing import Any, BinaryIO import torch -from logger import mprint -from puzzle_tools import deci_lm_hf_code -from puzzle_tools.common import infer_weights_dtype -from puzzle_tools.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM -from puzzle_tools.robust_json import json_dumps from safetensors.torch import save_file as safe_save_file from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from utils.post_init_sparse import SparsityMethod + +from modelopt.torch._compress.decilm import deci_lm_hf_code +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch._compress.tools.common import infer_weights_dtype +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.post_init_sparse import SparsityMethod +from modelopt.torch._compress.tools.robust_json import json_dumps SAFETENSORS_SUBBLOCKS_DIR_NAME = "subblocks_safetensors" PTH_SUBBLOCKS_DIR_NAME = "subblocks" diff --git a/modelopt/torch/_compress/tools/hydra_utils.py b/modelopt/torch/_compress/tools/hydra_utils.py new file mode 100644 index 000000000..64c403565 --- /dev/null +++ b/modelopt/torch/_compress/tools/hydra_utils.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for hydra config initialization. +""" + +import datetime +import random +from pathlib import Path + +from hydra import compose, initialize, initialize_config_dir +from hydra.utils import get_object +from omegaconf import DictConfig, OmegaConf + + +def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int: + """ + Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage. + Used as a resolver in hydra configs. + """ + steps = (int(tokens) // int(block)) // int(mbs) + w = pct * steps + return max(1, round(w)) + + +def register_hydra_resolvers(): + OmegaConf.register_new_resolver("to_path", lambda x: Path(x)) + OmegaConf.register_new_resolver( + "random_int", lambda low, high: random.randint(int(low), int(high)) + ) + OmegaConf.register_new_resolver( + "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None + ) + OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p)) + OmegaConf.register_new_resolver("get_object", lambda x: get_object(x)) + + +def initialize_hydra_config_for_dir( + config_dir: str, config_name: str, overrides: list[str] +) -> DictConfig: + """Initialize a hydra config from an absolute path for a config directory + + Args: + config_dir (str): + config_name (str): + overrides (List[str]): + + Returns: + DictConfig: + """ + + with initialize_config_dir(version_base=None, config_dir=config_dir): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args + + +def initialize_hydra_config(config_path: str, config_name: str, overrides: list[str]) -> DictConfig: + with initialize(version_base=None, config_path=config_path): + args = compose(config_name, overrides) + args._set_flag("allow_objects", True) + OmegaConf.resolve(args) # resolve object attributes + OmegaConf.set_struct(args, False) + + return args diff --git a/modelopt/torch/_compress/tools/post_init_sparse.py b/modelopt/torch/_compress/tools/post_init_sparse.py new file mode 100644 index 000000000..824d0856c --- /dev/null +++ b/modelopt/torch/_compress/tools/post_init_sparse.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors +import torch +from torch import nn +from torch.nn.utils.prune import custom_from_mask + +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM + +""" +Converts a state dictionary from PyTorch's pruning format (with _orig and _mask suffixes) +into a standard format with sparsified weights. +""" + + +class SparsityMethod: + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + gets a model state_dict, returns a state_dict-like mask_dict with masks + """ + + @staticmethod + def fix_state_dict_inplace(state_dict, verbose=False, change_dtype=False): + sparsity_masks = {} + for name in list(state_dict.keys()): + original_name = name.replace("_orig", "") + mask_name = original_name + "_mask" + if name[-4:] == "orig" and mask_name in state_dict: + val = state_dict[name] + mask = state_dict[name[:-4] + "mask"] + val[mask == 0] = 0 + sparsity = (val == 0).sum() / mask.numel() + sparsity_masks[original_name[:-7]] = mask + if verbose: + print(f"fix_state_dict_inplace: {name} {sparsity=}") + del state_dict[mask_name] + del state_dict[name] + state_dict[original_name] = val + if change_dtype: + for name in state_dict: + state_dict[name] = state_dict[name].to(torch.bfloat16) + return state_dict, sparsity_masks + + def filter_function(self): + pass + + def apply_masks(self, model: nn.Module, mask_dict: dict[str, torch.Tensor]) -> None: + for name, module in model.named_modules(): + if name in mask_dict: + custom_from_mask(module, "weight", mask_dict[name].to(module.weight.device)) + print(name) + print(torch.sum(mask_dict[name]) / mask_dict[name].numel()) + + def do_sparsity(self, model: DeciLMForCausalLM, mask_dict=None): + full_name_layers = [] + for block_idx, block_config in enumerate(model.config.block_configs): + ffn_names = block_config.ffn.sparsify # layers_to_sparsify_pattern[block_idx] + att_name = block_config.attention.sparsify + block = model.model.layers[block_idx] + if hasattr(block, "mlp"): + for name, m in block.mlp.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, ffn_names): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "mlp." + name + ) + if hasattr(block, "self_attn"): + for name, m in block.self_attn.named_modules(): + if isinstance(m, torch.nn.Linear) and self.filter_function(name, att_name): + full_name_layers.append( + "model.layers." + str(block_idx) + "." + "self_attn." + name + ) + + if mask_dict is None: + state_dict_for_sparsifying = { + k.rstrip(".weight"): v + for k, v in model.state_dict().items() + if k.rstrip(".weight") in full_name_layers + } + mask_dict = self.calculate_masks(state_dict_for_sparsifying) + # print('Apply sparsity') + # print(full_name_layers) + # print(model.state_dict().keys()) + # print(list(mask_dict.keys())) + + self.apply_masks(model, mask_dict) + + +class SparsityMethod2o4(SparsityMethod): + def calculate_masks(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + gets a model state_dict, returns a state_dict-like mask_dict with masks + """ + mask_dict = {} + for key, val in state_dict.items(): + orig_size = val.shape + scores = val.flatten() ** 2 + mask = self.create_mask(scores) + mask = mask.reshape(orig_size) + mask_dict[key] = mask + return mask_dict + + def create_mask(self, score, value=0): + score = score # .cpu() + orig_size = score.shape + score = score.view(-1, 4) + mask = torch.zeros(score.shape) + values, indices = torch.topk(score, 2, dim=1) + rows = torch.arange(mask.size(0)).unsqueeze(-1) + mask[rows, indices] = 1 + mask = mask.view(orig_size) + return mask # dev = score.device, return mask.to(dev) + + @staticmethod + def filter_function(name, modules_to_sparsify_in_block): + if modules_to_sparsify_in_block is None: + return False + return name in modules_to_sparsify_in_block diff --git a/modelopt/torch/_compress/tools/robust_json.py b/modelopt/torch/_compress/tools/robust_json.py new file mode 100644 index 000000000..dbb561b82 --- /dev/null +++ b/modelopt/torch/_compress/tools/robust_json.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Provides a robust JSON encoder that can handle various types of objects, +including dataclasses, paths, enums, namespaces, and functions. +""" + +import argparse +import dataclasses +import datetime +import inspect +import json +from enum import Enum +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + + +class RobustJSONEncoder(json.JSONEncoder): + def default(self, o): + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, Path): + return str(o) + if isinstance(o, Enum): + return o.name + if isinstance(o, argparse.Namespace): + return vars(o) + if type(o).__name__ == "dtype": + return str(o) + if isinstance(o, (DictConfig, ListConfig)): + return OmegaConf.to_container(o, resolve=True) + if inspect.isfunction(o) or inspect.ismethod(o): + if o.__module__ == "__main__": + # User-defined function in main — fallback to just the name + return o.__name__ + return f"{o.__module__}.{o.__qualname__}" + if isinstance(o, datetime.timedelta): + return str(o) + return super().default(o) + + +def json_dumps(obj: Any) -> str: + return json.dumps(obj, cls=RobustJSONEncoder, indent=2) + + +def json_dump(obj: Any, path: Path | str) -> None: + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + json_text = json_dumps(obj) + path.write_text(json_text) + + +def json_load(path: Path | str) -> dict: + path = Path(path) + text = path.read_text() + return json.loads(text) diff --git a/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py new file mode 100644 index 000000000..91fcb5ebd --- /dev/null +++ b/modelopt/torch/_compress/tools/sharded_checkpoint_utils.py @@ -0,0 +1,422 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Provides utilities for distributed loading, saving, and manipulation of +large language model checkpoints across multiple GPUs/processes. +""" + +import json +from collections.abc import Iterable, Mapping +from pathlib import Path +from typing import Literal, cast + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from huggingface_hub import split_torch_state_dict_into_shards +from safetensors import safe_open +from safetensors.torch import load_file as safe_load_file +from safetensors.torch import save_file as safe_save_file +from tqdm import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME +from transformers.utils.hub import cached_file, get_checkpoint_shard_files +from typing_extensions import override +from utils.utils import EmptyInitOnDevice + +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import ( + DeciLMDecoderLayer, + DeciLMForCausalLM, + rope_type_to_class, +) +from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.runtime import IRuntime + + +class DummyModule(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) + + @staticmethod + def load_state_dict_post_hook( + module: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys + ) -> None: + incompatible_keys.missing_keys.clear() + incompatible_keys.unexpected_keys.clear() + + +class DummyBlock(DummyModule): + def __init__(self, config: DeciLMConfig, block_index: int): + super().__init__() + self.config = config + self.block_index = block_index + + @override + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torch.Tensor, None]: + if self.config.block_return_only_hidden_states: + return x + else: + return x, None + + +class DummyWTE(DummyModule): + def __init__(self, config: DeciLMConfig, dtype: torch.dtype | None = None): + super().__init__() + self.n_embd = config.get_hidden_size() + self.dtype = dtype + + @override + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + B, T = input_ids.shape # noqa: N806 + result = torch.ones((B, T, self.n_embd), dtype=self.dtype, device=input_ids.device) + return result + + +class DummyLMHead(DummyModule): + def __init__(self, config: DeciLMConfig): + super().__init__() + self.vocab_size = config.vocab_size + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, C = x.shape # noqa: N806 + result = torch.ones((B, T, self.vocab_size), dtype=x.dtype, device=x.device) + return result + + +def create_local_shard_(model: DeciLMForCausalLM, owned_block_indexes: set[int]): + all_block_indexes = set(range(len(model.model.layers))) + has_first_block = 0 in owned_block_indexes + has_last_block = max(all_block_indexes) in owned_block_indexes + + unowned_block_indexes = all_block_indexes - owned_block_indexes + for block_index in unowned_block_indexes: + model.model.layers[block_index] = cast( + "DeciLMDecoderLayer", DummyBlock(model.config, block_index) + ) + + if not has_first_block: + model.set_input_embeddings(DummyWTE(model.config)) + + if not has_last_block: + model.model.set_final_layer_norm(nn.Identity()) + if not (model.config.tie_word_embeddings and has_first_block): + model.set_output_embeddings(DummyLMHead(model.config)) + + return model + + +def create_dummy_model( + model_config: DeciLMConfig, + dtype: torch.dtype, +) -> DeciLMForCausalLM: + with torch.device("meta"): + model = DeciLMForCausalLM(model_config) + + rope_cls = rope_type_to_class[model_config.position_embedding_type] + model.model.rotary_emb = rope_cls(config=model.config) + + model.model.set_input_embeddings(DummyWTE(model.config, dtype)) + model.model.set_final_layer_norm(nn.Identity()) + model.set_output_embeddings(DummyLMHead(model.config)) + + for block_index in range(model_config.get_num_hidden_layers()): + model.model.layers[block_index] = DummyBlock(model.config, block_index) + + return model + + +def load_and_shard_model( + runtime: IRuntime, + checkpoint_path: str | Path, + owned_block_indexes: set[int] | Literal["auto"] = "auto", + model_config: DeciLMConfig | None = None, + model_config_overrides: Mapping | None = None, +) -> DeciLMForCausalLM: + checkpoint_path = Path(checkpoint_path) + with runtime.device: + if model_config is None: + model_config = load_model_config( + checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True + ) + + if owned_block_indexes == "auto": + owned_block_indexes = set( + np.array_split(np.arange(model_config.get_num_hidden_layers()), runtime.world_size)[ + runtime.global_rank + ] + ) + + mprint("Initializing model shards") + model_shard = create_sharded_model( + runtime=runtime, + model_config=model_config, + owned_block_indexes=owned_block_indexes, + ) + + if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( + checkpoint_path / SAFE_WEIGHTS_INDEX_NAME + ).exists(): + mprint("Loading shard state_dict from safetensors") + shard_keys = [ + *[name for name, _ in model_shard.named_parameters()], + *[name for name, _ in model_shard.named_buffers()], + ] + shard_state_dict = load_sharded_state_dict( + model_name_or_path=str(checkpoint_path), + keys_to_load=shard_keys, + device=runtime.device, + ) + + new_names = set(shard_state_dict.keys()) + mprint(f"{new_names=}") + model_shard.load_state_dict(shard_state_dict, assign=True) + + del shard_state_dict + + if model_config.tie_word_embeddings and (0 in owned_block_indexes): + # re-tie the weights in case the connection was severed + model_shard.tie_weights() + else: + mprint("Loading state_dict in main process") + state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None + + mprint("Distributing model to shards") + load_state_dict_to_shards( + runtime=runtime, model_shard=model_shard, loaded_state_dict=state_dict + ) + del state_dict + + model_shard.type(runtime.dtype) + + params_on_meta_device = [ + param_name + for param_name, param in model_shard.named_parameters() + if param.device == torch.device("meta") + ] + assert len(params_on_meta_device) == 0, ( + f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" + ) + + return model_shard + + +def create_sharded_model( + runtime: IRuntime, + model_config: DeciLMConfig, + owned_block_indexes: set[int], + device: str | torch.device | None = "meta", + dtype: torch.dtype | None = torch.float32, +): + if isinstance(device, str): + device = torch.device(device) + + runtime.wait_for_everyone() + + with EmptyInitOnDevice(device="meta", dtype=dtype): + model = DeciLMForCausalLM(model_config) + create_local_shard_(model=model, owned_block_indexes=owned_block_indexes) + + if device != torch.device("meta"): + local_shard_state_dict = { + k: torch.empty_like(v, device=device) for k, v in model.state_dict().items() + } + + model.load_state_dict(local_shard_state_dict, assign=True) + + return model + + +def load_state_dict_to_shards( + runtime: IRuntime, model_shard: torch.nn.Module, loaded_state_dict: dict | None = None +) -> None: + from sewing_kit.utils import distributed_isend_obj, distributed_recv_obj + + model_shard.to("meta") + local_state_dict_keys = list(model_shard.state_dict().keys()) + + if runtime.is_main_process: + gathered_state_dict_keys = [None] * runtime.world_size + torch.distributed.gather_object(local_state_dict_keys, gathered_state_dict_keys) + + assert loaded_state_dict is not None + loaded_state_dict = {k.replace("_orig_mod.", ""): v for k, v in loaded_state_dict.items()} + + works: list[torch.distributed.Work] = [] + for i, shard_keys in enumerate(gathered_state_dict_keys[1:]): + process_id = i + 1 + shard_state_dict = {k: v for k, v in loaded_state_dict.items() if k in shard_keys} + process_works = distributed_isend_obj(shard_state_dict, process_id) + works.extend(process_works) + + for work in works: + work.wait() + + shard_state_dict = { + k: v for k, v in loaded_state_dict.items() if k in local_state_dict_keys + } + else: + torch.distributed.gather_object(local_state_dict_keys) + shard_state_dict = distributed_recv_obj() + + print(f"{runtime.global_rank=} loaded state_dict shard") + + missing_keys, unexpected_keys = model_shard.load_state_dict( + shard_state_dict, strict=False, assign=True + ) + assert len(unexpected_keys) == 0 + assert all("dummy_param" in key for key in missing_keys) + + model_shard.to(runtime.device) + + runtime.wait_for_everyone() + + +def save_sharded_model( + runtime: IRuntime, + model_shard: torch.nn.Module | dict[str, torch.Tensor], + out_path: str | Path, +): + """ + out_path is usually output_checkpoint_path / "model.safetensors" + """ + runtime.wait_for_everyone() + + if isinstance(model_shard, torch.nn.Module): + shard_state_dict = model_shard.state_dict() + elif isinstance(model_shard, dict): + shard_state_dict = model_shard + else: + raise ValueError(f"Unrecognized model shard type: {type(model_shard)}") + + shard_state_dict = {k: v.cpu() for k, v in shard_state_dict.items()} + total_shard_size = sum( + weight.numel() * weight.element_size() for weight in shard_state_dict.values() + ) + + num_shards = runtime.world_size + idx = runtime.global_rank + + out_path = Path(out_path) + shard_file = out_path.with_stem(f"{out_path.stem}-{idx + 1:05d}-of-{num_shards:05d}") + + shard_metadata = { + "total_shard_size": total_shard_size, + "shard_keys": list(shard_state_dict.keys()), + "shard_file": str(shard_file), + } + + if runtime.is_main_process: + shard_metadatas = [{} for _ in range(runtime.world_size)] + torch.distributed.gather_object(shard_metadata, shard_metadatas, dst=0) + total_size = sum(x["total_shard_size"] for x in shard_metadatas) + metadata = {"total_size": total_size} + weight_map: dict[str, str] = {} + for shard_metadata in shard_metadatas: + weight_map.update( + {k: Path(shard_metadata["shard_file"]).name for k in shard_metadata["shard_keys"]} + ) + + index = {"metadata": metadata, "weight_map": weight_map} + index_path = Path(str(out_path) + ".index.json") + index_path.write_text(json.dumps(index, indent=2)) + + else: + torch.distributed.gather_object(shard_metadata, dst=0) + + if out_path.suffix == ".safetensors": + safe_save_file(shard_state_dict, shard_file, metadata={"format": "pt"}) + else: + torch.save(shard_state_dict, shard_file) + + runtime.wait_for_everyone() + + +def save_sharded_state_dict( + state_dict: dict[str, torch.Tensor], + save_directory: str | Path, + max_shard_size: str = "10GB", +) -> None: + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + state_dict = {k: v.cpu() for k, v in state_dict.items()} + + state_dict_split = split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size) + + for shard_filename, param_names in tqdm( + state_dict_split.filename_to_tensors.items(), desc="saving sharded state dict" + ): + shard_path = save_directory / shard_filename + shard = {param_name: state_dict[param_name] for param_name in param_names} + safe_save_file(shard, shard_path, metadata={"format": "pt"}) + + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + index_path = save_directory / SAFE_WEIGHTS_INDEX_NAME + index_path.write_text(json.dumps(index, indent=2)) + + +def load_sharded_state_dict( + model_name_or_path: str | Path, + keys_to_load: Iterable[str] | None = None, + device: torch.device | str = "cpu", +) -> dict[str, torch.Tensor]: + """ + keys_to_load: entire state_dict if None, else partial state_dict containing only these keys + """ + shard_paths = _resolve_shard_paths(model_name_or_path) + # print(f"shard_paths: {shard_paths}") + partial_state_dict = {} + for safetensors_path in shard_paths: + if keys_to_load is None: + shard = safe_load_file(safetensors_path) + partial_state_dict.update(shard) + else: + with safe_open(safetensors_path, framework="pt", device=str(device)) as f: + for key in f: + if key in keys_to_load: + partial_state_dict[key] = f.get_tensor(key) + return partial_state_dict + + +def _resolve_shard_paths(model_name_or_path: str) -> list[str]: + try: + unsharded_path = cached_file(model_name_or_path, SAFE_WEIGHTS_NAME) + return [unsharded_path] + except OSError: + index_path = cached_file(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + shard_paths, _ = get_checkpoint_shard_files(model_name_or_path, index_path) + return shard_paths + + +def is_in_safetensors_format(checkpoint_dir: Path) -> bool: + return len(list(checkpoint_dir.glob("*.safetensors"))) > 0 + + +def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]: + shard_paths = _resolve_shard_paths(model_name_or_path) + state_dict_shapes = {} + for safetensors_path in shard_paths: + with safe_open(safetensors_path, framework="pt") as f: + for key in f: + state_dict_shapes[key] = tuple(f.get_tensor(key).shape) + return state_dict_shapes diff --git a/tests/experimental/torch/_compress/compress_test_utils.py b/tests/experimental/torch/_compress/compress_test_utils.py index f0704f6c8..160098922 100644 --- a/tests/experimental/torch/_compress/compress_test_utils.py +++ b/tests/experimental/torch/_compress/compress_test_utils.py @@ -19,9 +19,10 @@ import torch from datasets import Dataset, DatasetDict -from puzzle_tools.hydra_utils import register_hydra_resolvers from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers + def setup_test_model_and_data( project_root_path: Path, diff --git a/tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py b/tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py similarity index 100% rename from tests/experimental/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py rename to tests/gpu/torch/_compress/decilm/converters/test_convert_llama3_config_to_decilm_config.py diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index a38322d14..cd4e34ca1 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + import pytest import torch import torch.distributed as dist @@ -57,3 +59,9 @@ def set_torch_dtype(request): @pytest.fixture(scope="session", autouse=True) def enable_hf_checkpointing(): mto.enable_huggingface_checkpointing() + + +@pytest.fixture +def project_root_path(request: pytest.FixtureRequest) -> Path: + """Fixture providing the project root path for tests.""" + return Path(request.config.rootpath)