Merged

Changes from all commits (41 commits)
694c317
Add decilm modelling code
danielkorzekwa Nov 3, 2025
991659f
Add decilm modelling code.
danielkorzekwa Nov 3, 2025
8489cee
Add transformers codebase
danielkorzekwa Nov 3, 2025
f0afefe
Add transformers code
danielkorzekwa Nov 3, 2025
b3ed5bc
Add decilm modelling code
danielkorzekwa Nov 3, 2025
a700da5
Add decilm modelling code
danielkorzekwa Nov 3, 2025
b59b679
Correct licence headers
danielkorzekwa Nov 4, 2025
1abdf3e
Correct licence headers
danielkorzekwa Nov 4, 2025
66609b1
Add decilm code
danielkorzekwa Nov 4, 2025
7da0a8a
Add decilm code
danielkorzekwa Nov 4, 2025
6e09a81
Add decilm code
danielkorzekwa Nov 4, 2025
2e3f5da
Add decilm code
danielkorzekwa Nov 4, 2025
418890e
Add decilm code
danielkorzekwa Nov 4, 2025
01f4fc1
Make llama3 converter self-contained (no deps on internal Nvidia code)
danielkorzekwa Nov 4, 2025
c57eed4
Add common module
danielkorzekwa Nov 4, 2025
3dc37b3
module refactoring
danielkorzekwa Nov 4, 2025
10ffdfe
refactoring
danielkorzekwa Nov 5, 2025
27a4456
add shared_checkpointing_utils
danielkorzekwa Nov 5, 2025
b0e22b7
Add json tools
danielkorzekwa Nov 5, 2025
52e7827
add logger
danielkorzekwa Nov 5, 2025
f5c1c87
import refactoring
danielkorzekwa Nov 5, 2025
0aa6320
add post_init_sparse module
danielkorzekwa Nov 5, 2025
35d0dbc
Add post_init_sparse
danielkorzekwa Nov 5, 2025
e39a1ad
merging hydra.py and hydra_utils.py
danielkorzekwa Nov 5, 2025
1bd0c67
Add integration test for attention pruning
danielkorzekwa Nov 5, 2025
0ecd52b
add score_pruning_activations
danielkorzekwa Nov 5, 2025
278c6b7
import refactoring
danielkorzekwa Nov 5, 2025
7a0af16
add dist_utils
danielkorzekwa Nov 5, 2025
0f0cbbd
Add validate_model
danielkorzekwa Nov 5, 2025
03af4f7
Merge branch 'feature/compress' into dkorzekwa/score_pruning_activati…
danielkorzekwa Nov 14, 2025
d3ff495
delete unused tokenizer
danielkorzekwa Nov 14, 2025
6189bec
fix logger import
danielkorzekwa Nov 14, 2025
6d786a5
Update modelopt/torch/_compress/tools/validate_model.py
danielkorzekwa Nov 17, 2025
4b403be
Remove lit-llama specific comments
danielkorzekwa Nov 17, 2025
c051630
remove 'lustre' specific folders from comments
danielkorzekwa Nov 17, 2025
23ac46a
remove lit-llama specific code
danielkorzekwa Nov 17, 2025
c1459bc
Add TODO
danielkorzekwa Nov 17, 2025
41dfd43
Remove unused quantization_mode from EmptyInitOnDevice
danielkorzekwa Nov 17, 2025
5ec40d8
Remove unused code from utils.py
danielkorzekwa Nov 17, 2025
f4de171
Regenerate licence header
danielkorzekwa Nov 17, 2025
54c1505
Removed untested instructions
danielkorzekwa Nov 18, 2025
@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig
from utils.parsing import format_global_config

from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
from modelopt.torch._compress.tools.logger import mprint
from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime
from modelopt.torch._compress.tools.validate_model import validate_model
from modelopt.torch._compress.utils.dist_utils import is_distributed


def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool:
    """
    Determine if the activation hook method has proper checkpoint support implemented.

    Args:
        activation_hooks_kwargs: Hook configuration

    Returns:
        bool: True if the hook method has save_state/load_state implemented
    """
    method = activation_hooks_kwargs.get("method", "")

    # Methods with implemented checkpoint support
    supported_methods = {
        "iterative",  # IterativeChannelContributionHook: save_state/load_state implemented
        "independent",  # IndependentChannelContributionHook: save_state/load_state implemented
        "stats",  # RouterStatsHook: save_state/load_state implemented
        "ranked_choice_voting",  # RankedChoiceVotingHook: save_state/load_state implemented
    }

    return method in supported_methods


def check_scoring_completion(
    activations_log_dir: str, runtime, activation_hooks_kwargs=None
) -> bool:
    """
    Check if scoring is already completed by looking for the expected output files.
    Also checks if the scoring method is safe for resume.

    Args:
        activations_log_dir: Directory where activation logs should be stored
        runtime: Runtime object for distributed processing
        activation_hooks_kwargs: Hook configuration to check if resume is safe

    Returns:
        bool: True if scoring is completed (has rank files and args.json)
    """
    # Only check completion on main process (or if no distributed runtime)
    if runtime is None or runtime.is_main_process:
        log_dir = Path(activations_log_dir)

        # Check if directory exists
        if not log_dir.exists():
            return False

        # Check for rank files (at least rank_0.pth should exist)
        rank_files = list(log_dir.glob("rank_*.pth"))

        if not rank_files:
            return False

        # Check for args.json (created by main process)
        args_file = log_dir / "args.json"
        has_args_json = args_file.exists()

        # Check for completion: if we have rank files and args.json, scoring is complete
        if rank_files and has_args_json:
            # Add optional completion info for debugging
            mprint(f"Found completed scoring in {activations_log_dir}")
            mprint(f" - Found {len(rank_files)} rank files")
            mprint(f" - Found args.json: {has_args_json}")

            return True

    return False


def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool:
    """
    Determine if we should skip scoring entirely (only if 100% complete).
    Partial progress should proceed to validate_model for proper resume.

    Args:
        cfg: Configuration object
        runtime: Runtime object for distributed processing

    Returns:
        bool: True if we should skip scoring (100% completed), False if we should run/resume it
    """
    # Check if activations_log_dir is specified
    if not hasattr(cfg.pruning, "activations_log_dir") or cfg.pruning.activations_log_dir is None:
        mprint("No activations_log_dir specified, running scoring")
        return False

    # Check for force restart flag
    force_restart = getattr(cfg.pruning, "force_restart_scoring", False)
    if force_restart:
        mprint("Force restart flag set, will restart scoring regardless of existing artifacts")
        return False

    # Get hook configuration to check if resume is mathematically safe
    activation_hooks_kwargs = getattr(cfg.pruning, "activation_hooks_kwargs", {})

    # Check if scoring is already completed
    is_completed = check_scoring_completion(
        cfg.pruning.activations_log_dir, runtime, activation_hooks_kwargs
    )

    # Broadcast the result to all processes in distributed mode
    if runtime is not None and runtime.world_size > 1:
        should_skip = [is_completed]  # Use list for mutable object
        torch.distributed.broadcast_object_list(should_skip, src=0)
        is_completed = should_skip[0]

    if is_completed:
        mprint("Scoring 100% completed, skipping...")

    return is_completed


# Old progress tracking removed - checkpoint manager handles all progress tracking


def launch_score_activations(cfg: DictConfig, runtime):
    # Check if we should skip scoring entirely (only if 100% complete)
    if should_skip_scoring_completely(cfg, runtime):
        return

    mprint("Starting pruning activation scoring...")

    # The checkpoint manager inside validate_model handles all progress tracking
    validate_model(args=cfg.pruning, runtime=runtime)


@hydra.main("", version_base="1.3")
def main(cfg: DictConfig) -> None:
    cfg = hydra.utils.instantiate(cfg)
    mprint(format_global_config(cfg, title="Score Pruning Activations"))

    _runtime = (
        NativeDdpRuntime(
            dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes")
        )
        if is_distributed()
        else BaseRuntime(dtype=torch.bfloat16)
    )
    with _runtime as runtime:
        launch_score_activations(cfg, runtime)
        runtime.wait_for_everyone()


if __name__ == "__main__":
    register_hydra_resolvers()
    main()
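The skip decision above is made only on rank 0 and then shared with every worker via torch.distributed.broadcast_object_list, so all ranks agree on whether to run scoring. Below is a minimal standalone sketch of that rank-0 consensus pattern; the helper name decide_on_rank0_and_broadcast and the args.json check are hypothetical stand-ins for illustration, not part of this PR.

import os

import torch.distributed as dist


def decide_on_rank0_and_broadcast(decision_fn) -> bool:
    """Evaluate a boolean decision on rank 0 and share the result with all ranks."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    result = decision_fn() if rank == 0 else None  # placeholder on non-zero ranks

    if dist.is_initialized() and dist.get_world_size() > 1:
        payload = [result]  # broadcast_object_list mutates the list in place
        dist.broadcast_object_list(payload, src=0)
        result = payload[0]

    return bool(result)


if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc_per_node=2 broadcast_sketch.py
    dist.init_process_group(backend="gloo")
    skip = decide_on_rank0_and_broadcast(lambda: os.path.exists("args.json"))
    print(f"rank {dist.get_rank()}: skip={skip}")
    dist.destroy_process_group()

Using a one-element list is the standard idiom here: broadcast_object_list fills the list in place on receiving ranks, so every rank reads the rank-0 value out of payload[0].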
@@ -26,11 +26,11 @@
 import build_library_and_stats
 import mip_and_realize_models
 import pruning_ckpts
-import score_pruning_activations
 import scoring
 import torch
 from torch import nn

+from modelopt.torch._compress.activation_scoring import score_pruning_activations
 from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (
     convert_llama3_to_decilm,
 )
8 changes: 4 additions & 4 deletions modelopt/torch/_compress/tools/sharded_checkpoint_utils.py
@@ -29,24 +29,24 @@
 import torch.distributed
 import torch.nn as nn
 from huggingface_hub import split_torch_state_dict_into_shards
-from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
 from safetensors import safe_open
 from safetensors.torch import load_file as safe_load_file
 from safetensors.torch import save_file as safe_save_file
 from tqdm import tqdm
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 from typing_extensions import override
-from utils.utils import EmptyInitOnDevice

 from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig
 from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import (
     DeciLMDecoderLayer,
+    DeciLMForCausalLM,
     rope_type_to_class,
 )
 from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict
 from modelopt.torch._compress.tools.logger import mprint
 from modelopt.torch._compress.tools.runtime import IRuntime
+from modelopt.torch._compress.utils.utils import EmptyInitOnDevice


 class DummyModule(nn.Module):
@@ -392,7 +392,7 @@ def load_sharded_state_dict(
             partial_state_dict.update(shard)
     else:
         with safe_open(safetensors_path, framework="pt", device=str(device)) as f:
-            for key in f:
+            for key in f.keys():  # noqa: SIM118 - safe_open objects require .keys(), not directly iterable
                 if key in keys_to_load:
                     partial_state_dict[key] = f.get_tensor(key)
     return partial_state_dict
@@ -417,6 +417,6 @@ def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]:
     state_dict_shapes = {}
     for safetensors_path in shard_paths:
         with safe_open(safetensors_path, framework="pt") as f:
-            for key in f:
+            for key in f.keys():  # noqa: SIM118 - safe_open objects require .keys(), not directly iterable
                 state_dict_shapes[key] = tuple(f.get_tensor(key).shape)
     return state_dict_shapes
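Both hunks above replace for key in f: with for key in f.keys() because, as the noqa comment notes, the handle returned by safetensors.safe_open is not directly iterable, so the explicit .keys() call that lint rule SIM118 would normally flag is required here. A minimal self-contained sketch of this partial-loading pattern follows; the file name and tensor names are made up for illustration.

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Write a tiny shard so the example is self-contained (names are illustrative).
save_file({"layer.weight": torch.ones(2, 2), "layer.bias": torch.zeros(2)}, "shard.safetensors")

keys_to_load = {"layer.weight"}
partial_state_dict = {}

with safe_open("shard.safetensors", framework="pt", device="cpu") as f:
    # The handle is not directly iterable, so iterate over f.keys() explicitly,
    # for the same reason the diff above carries a noqa: SIM118.
    for key in f.keys():  # noqa: SIM118
        if key in keys_to_load:
            # Tensors are read on demand, so only the requested ones are materialized.
            partial_state_dict[key] = f.get_tensor(key)

print(sorted(partial_state_dict))  # ['layer.weight']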