diff --git a/modelopt/torch/_compress/build_library_and_stats.py b/modelopt/torch/_compress/build_library_and_stats.py new file mode 100644 index 0000000000..19bd4f03cc --- /dev/null +++ b/modelopt/torch/_compress/build_library_and_stats.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified command that runs build_replacement_library followed by calc_subblock_stats. + +This script combines the functionality of both commands into a single workflow: +1. First, it builds the replacement library for the puzzle +2. Then, it calculates subblock statistics + +Usage: + + python modelopt.torch._compress.build_library_and_stats.py --config-dir configs --config-name Llama-3_1-8B puzzle_dir=/path/to/puzzle/dir dataset_path=/path/to/dataset + +The script uses the same Hydra configuration as the individual commands and supports +all the same configuration parameters for both build_replacement_library and calc_subblock_stats. 
+""" + +import hydra +from calc_subblock_stats import launch_calc_subblock_stats +from omegaconf import DictConfig + +from modelopt.torch._compress.replacement_library.build_replacement_library import ( + launch_build_replacement_library, +) +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.utils.parsing import format_global_config + + +def launch_build_library_and_stats(cfg: DictConfig) -> None: + """ + Launch both build_replacement_library and calc_subblock_stats in sequence. + + Args: + cfg: Hydra configuration containing settings for both commands + """ + mprint("=" * 80) + mprint("STARTING UNIFIED BUILD LIBRARY AND STATS WORKFLOW") + mprint("=" * 80) + + # Step 1: Build replacement library + mprint("=" * 50) + mprint("STEP 1: Building Replacement Library") + mprint("=" * 50) + + try: + launch_build_replacement_library(cfg) + mprint("✅ Replacement library built successfully!") + except Exception as e: + mprint(f"❌ Failed to build replacement library: {e}") + raise + + # Step 2: Calculate subblock statistics + mprint("=" * 50) + mprint("STEP 2: Calculating Subblock Statistics") + mprint("=" * 50) + + try: + launch_calc_subblock_stats(cfg) + mprint("✅ Subblock statistics calculated successfully!") + except Exception as e: + mprint(f"❌ Failed to calculate subblock statistics: {e}") + raise + + mprint("=" * 80) + mprint("UNIFIED WORKFLOW COMPLETED SUCCESSFULLY! 
🎉") + mprint("=" * 80) + + mprint("Generated files:") + mprint(f" - {cfg.puzzle_dir}/block_library.json") + mprint(f" - {cfg.puzzle_dir}/subblock_library.json") + mprint(f" - {cfg.puzzle_dir}/replacement_library.json") + mprint(f" - {cfg.puzzle_dir}/single_sequence_replacement_solutions.json") + mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.subblock_stats_filename}") + if hasattr(cfg.calc_subblock_stats, "moe_stats_filename"): + mprint(f" - {cfg.puzzle_dir}/{cfg.calc_subblock_stats.moe_stats_filename}") + + +@hydra.main("", version_base="1.3") +def main(cfg: DictConfig) -> None: + """ + Main entry point for the unified build library and stats command. + + This function uses Hydra for configuration management and runs both + build_replacement_library and calc_subblock_stats in sequence. + """ + cfg = hydra.utils.instantiate(cfg) + mprint("Unified Build Library and Stats Configuration:") + mprint(format_global_config(cfg)) + launch_build_library_and_stats(cfg) + + +if __name__ == "__main__": + register_hydra_resolvers() + main() diff --git a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py index 8fbf7c7c47..72c40f729f 100644 --- a/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py +++ b/modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py @@ -23,13 +23,13 @@ import datetime from pathlib import Path -import build_library_and_stats import mip_and_realize_models import scoring import torch from torch import nn import modelopt.torch._compress.pruning.pruning_ckpts as pruning_ckpts +from modelopt.torch._compress import build_library_and_stats from modelopt.torch._compress.activation_scoring import score_pruning_activations from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import ( convert_llama3_to_decilm, @@ -123,6 +123,7 @@ def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertR ) # Convert Llama3 model to DeciLM model + # 
TODO: Make it generic, do not call convert_llama3_to_decilm directly. if runtime.global_rank == 0: mprint("Compress Progress 2/8: converting model from HF to DeciLM (single-gpu)") hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable diff --git a/modelopt/torch/_compress/replacement_library/build_replacement_library.py b/modelopt/torch/_compress/replacement_library/build_replacement_library.py new file mode 100644 index 0000000000..a8b2b7f9b6 --- /dev/null +++ b/modelopt/torch/_compress/replacement_library/build_replacement_library.py @@ -0,0 +1,605 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module constructs the replacement library JSON files from a puzzle directory containing +multiple trained model checkpoints. It analyzes checkpoints to extract unique block and subblock +configurations, builds a library of available replacements, and generates solutions for layer +replacement in compressed models. The resulting replacement library can then be used by +ReplacementLibrary to efficiently load models with mixed teacher/student layers. 
+ +Standard Puzzle Usage: +====================== +python -m modelopt.torch._compress.replacement_library.build_replacement_library PUZZLE_DIR + +Teacher checkpoint dir is assumed to be inside PUZZLE_DIR/ckpts/teacher (symlink is recommended) +though you can supply an explicit --teacher_checkpoint_dir. + +--add_ffn_no_ops and --add_attention_no_ops are optional (default True), + + +Untrained puzzle run (with bypass): +=================================== +The subblock that doesn't interest you in the checkpoint should be no_op. + +""" +# mypy: ignore-errors + +import json +from pathlib import Path +from typing import Any, Type + +import hydra +import pandas as pd +from omegaconf import DictConfig + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) +from modelopt.torch._compress.replacement_library.replacement_utils import ( + is_replacement_identical_to_teacher, + replacement_is_teacher, + sort_replacements, +) +from modelopt.torch._compress.tools.checkpoint_utils import ( + SAFETENSORS_SUBBLOCKS_DIR_NAME, + is_valid_decilm_checkpoint, + load_model_config, +) +from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers +from modelopt.torch._compress.tools.logger import mprint +from modelopt.torch._compress.tools.robust_json import json_dump +from modelopt.torch._compress.utils.parsing import format_global_config +from modelopt.torch._compress.utils.utils import block_config_to_str, subblock_config_to_str + +UNIQUE_SUBBLOCK_IDENTIFIER = ["block_config", "attention_config", "ffn_config", "block_idx"] +CHECKPOINTS_DIR_NAME = "ckpts" + + +def build_replacement_library( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str | None = None, + add_ffn_no_ops: bool = True, + add_attention_no_ops: bool = True, +) -> None: + """ + For normal puzzle runs, use default values. + For advanced use cases, see the Usage section. 
+ """ + master_puzzle_dir = Path(master_puzzle_dir) + (master_puzzle_dir / "ckpts").mkdir(exist_ok=True) + teacher_checkpoint_dir = infer_teacher_dir(master_puzzle_dir, teacher_checkpoint_dir) + subblocks_df = _build_subblocks_df( + master_puzzle_dir, + teacher_checkpoint_dir, + add_ffn_no_ops, + add_attention_no_ops, + ) + block_library_df = _build_block_library_from_subblocks(subblocks_df) + + layer_replacements = _build_layer_replacements( + block_library_df, master_puzzle_dir, teacher_checkpoint_dir + ) + + single_sequence_replacement_solutions = _build_single_sequence_replacement_solutions( + layer_replacements, teacher_checkpoint_dir + ) + + json_dump(block_library_df.to_dict(orient="records"), master_puzzle_dir / "block_library.json") + json_dump(subblocks_df.to_dict(orient="records"), master_puzzle_dir / "subblock_library.json") + json_dump(layer_replacements, master_puzzle_dir / "replacement_library.json") + json_dump( + single_sequence_replacement_solutions, + master_puzzle_dir / "single_sequence_replacement_solutions.json", + ) + mprint("done") + + +def launch_build_replacement_library(cfg: DictConfig) -> None: + """ + Launch the build replacement library function with Hydra configuration. 
+ """ + mprint(f"Building replacement library for puzzle directory: {cfg.puzzle_dir}") + mprint(f"Teacher directory: {cfg.teacher_dir}") + mprint( + f"Build replacement library config: {format_global_config(cfg.build_replacement_library, title='Build replacement library')}" + ) + + build_replacement_library( + master_puzzle_dir=cfg.puzzle_dir, + teacher_checkpoint_dir=cfg.teacher_dir, + add_ffn_no_ops=cfg.build_replacement_library.add_ffn_no_ops, + add_attention_no_ops=cfg.build_replacement_library.add_attention_no_ops, + ) + + +def infer_teacher_dir( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str | None = None, +) -> Path: + if teacher_checkpoint_dir is None: + teacher_checkpoint_dir = Path(master_puzzle_dir) / CHECKPOINTS_DIR_NAME / "teacher" + if not teacher_checkpoint_dir.exists(): + raise ValueError( + f"You must either provide the --teacher_checkpoint_dir argument, or create a link to the " + f"teacher dir under '{{PUZZLE_DIR}}/ckpts'." + ) + teacher_checkpoint_dir = Path(teacher_checkpoint_dir).resolve().absolute() + return teacher_checkpoint_dir + + +def _build_block_library_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame: + joint_blocks_df = subblocks_df.dropna(subset=["block_config"]).copy() + constructed_blocks_df = _construct_blocks_from_subblocks(subblocks_df) + + is_constructed_block_has_joint_variant = pd.Series( + map(tuple, constructed_blocks_df[["block_config", "block_idx"]].values) + ).isin(pd.Series(map(tuple, joint_blocks_df[["block_config", "block_idx"]].values))) + constructed_blocks_df = constructed_blocks_df[~is_constructed_block_has_joint_variant] + + block_library_df = pd.concat([joint_blocks_df, constructed_blocks_df]) + block_library_df["block_repr"] = block_library_df["block_config"].apply(block_config_to_str) + + dups = block_library_df.loc[ + block_library_df[["block_config", "block_idx"]].duplicated() + ].sort_values(by=["block_config", "block_idx"]) + if len(dups) > 0: + mprint(f"Found {len(dups)} 
duplicate blocks in the block library. Here are some examples:") + dup_block_idx = dups["block_idx"].iloc[0] + dups_with_same_block_idx = dups[dups["block_idx"] == dup_block_idx] + for _, row in dups_with_same_block_idx.head(10).iterrows(): + mprint(row.to_dict()) + json_dump(block_library_df.to_dict(orient="records"), "ERROR_block_library.json") + json_dump(subblocks_df.to_dict(orient="records"), "ERROR_subblock_library.json") + raise ValueError( + f"Found {len(dups)} duplicate blocks in the block library. See ERROR_block_library.json and ERROR_subblock_library.json for more details." + ) + + return block_library_df + + +def _construct_blocks_from_subblocks(subblocks_df: pd.DataFrame) -> pd.DataFrame: + columns = subblocks_df.columns + decomp_blocks_df = subblocks_df[subblocks_df["block_config"].isna()].drop( + columns=columns[columns.str.contains("block_config|joint|block_repr")] + ) + + attention_df = decomp_blocks_df.dropna(subset="attention_config").drop( + columns=columns[columns.str.contains("ffn")] + ) + ffn_df = decomp_blocks_df.dropna(subset="ffn_config").drop( + columns=columns[columns.str.contains("attention")] + ) + constructed_blocks_df = pd.merge(attention_df, ffn_df, on="block_idx") + + constructed_blocks_df["block_config"] = constructed_blocks_df.apply( + lambda row: BlockConfig(ffn=row["ffn_config"], attention=row["attention_config"]), axis=1 + ) + + return constructed_blocks_df + + +def _build_subblocks_df( + master_puzzle_dir: Path | str, + teacher_checkpoint_dir: Path | str, + add_ffn_no_ops: bool, + add_attention_no_ops: bool, +) -> pd.DataFrame: + teacher_checkpoint_dir = Path(teacher_checkpoint_dir) + checkpoint_dirs = _get_last_checkpoint_from_each_experiment(master_puzzle_dir) + checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir}) + checkpoints_to_split = [teacher_checkpoint_dir] + + subblock_rows = [] + for checkpoint_dir in checkpoint_dirs: + subblocks_to_extract = 
_infer_subblocks_to_extract(checkpoint_dir, checkpoints_to_split) + if len(subblocks_to_extract) > 0: + subblock_rows_from_current_checkpoint = ( + _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir, subblocks_to_extract + ) + ) + subblock_rows.extend(subblock_rows_from_current_checkpoint) + + subblocks_df = pd.DataFrame(subblock_rows) + + subblocks_df = _drop_duplicates_of_decomp_no_op(subblocks_df) + assert subblocks_df.duplicated().sum() == 0 + + if add_ffn_no_ops or add_attention_no_ops: + subblocks_df = _add_no_op_subblock_rows(subblocks_df, add_ffn_no_ops, add_attention_no_ops) + + subblocks_df = _drop_duplicates_of_teacher(subblocks_df, teacher_checkpoint_dir) + + subblocks_that_have_multiple_sources = list( + subblocks_df[subblocks_df.duplicated(UNIQUE_SUBBLOCK_IDENTIFIER, keep=False)].groupby( + UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False + ) + ) + if len(subblocks_that_have_multiple_sources) > 0: + mprint( + f"Found {len(subblocks_that_have_multiple_sources)} subblock types with multiple sources. Dropping duplicates..." + ) + for subblock_identifier, duplicates_df in subblocks_that_have_multiple_sources: + mprint("\n================================") + mprint(dict(zip(UNIQUE_SUBBLOCK_IDENTIFIER, subblock_identifier))) + for _, row in duplicates_df.iterrows(): + mprint(row.to_dict()) + + # Drop duplicates, keeping the first occurrence (which should be from teacher) + mprint(f"Dropping duplicates. 
Original count: {len(subblocks_df)}") + subblocks_df = subblocks_df.drop_duplicates(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + mprint(f"After dropping duplicates: {len(subblocks_df)}") + + subblocks_df["ffn_repr"] = subblocks_df["ffn_config"].apply(subblock_config_to_str) + subblocks_df["attention_repr"] = subblocks_df["attention_config"].apply(subblock_config_to_str) + subblocks_df["block_repr"] = subblocks_df["block_config"].apply(block_config_to_str) + + return subblocks_df + + +def _drop_duplicates_of_teacher( + subblocks_df: pd.DataFrame, + teacher_checkpoint_dir: Path | str, +) -> pd.DataFrame: + orig_subblocks_df = subblocks_df.copy() + + attention_is_teacher = subblocks_df["attention_checkpoint_dir"] == str(teacher_checkpoint_dir) + ffn_is_teacher = subblocks_df["ffn_checkpoint_dir"] == str(teacher_checkpoint_dir) + is_joint_teacher = attention_is_teacher & ffn_is_teacher + + is_decomp_attention = subblocks_df["ffn_config"].isna() + is_decomp_ffn = subblocks_df["attention_config"].isna() + is_joint_block = ~is_decomp_attention & ~is_decomp_ffn + + student_indices_that_have_teacher_dups = [] + + for current_subset, is_teacher in [ + (is_decomp_attention, attention_is_teacher), + (is_decomp_ffn, ffn_is_teacher), + (is_joint_block, is_joint_teacher), + ]: + subblocks_df = orig_subblocks_df.copy().loc[current_subset] + + subblocks_df["is_student"] = ~is_teacher.loc[current_subset] + + def get_student_indices_that_have_teacher_dups(grouped_is_student: pd.Series) -> list: + if grouped_is_student.all(): + return [] + return grouped_is_student.index[grouped_is_student].tolist() + + current_student_indices_that_have_teacher_dups = [ + dup_index + for dup_list in subblocks_df.groupby(UNIQUE_SUBBLOCK_IDENTIFIER, dropna=False)[ + "is_student" + ].apply(get_student_indices_that_have_teacher_dups) + for dup_index in dup_list + ] + student_indices_that_have_teacher_dups.extend( + current_student_indices_that_have_teacher_dups + ) + + dedup_subblocks_df = 
orig_subblocks_df.drop(index=student_indices_that_have_teacher_dups) + return dedup_subblocks_df + + +def _drop_duplicates_of_decomp_no_op(subblocks_df: pd.DataFrame) -> pd.DataFrame: + is_decomp = subblocks_df["block_config"].isna() + is_ffn_no_op = subblocks_df["ffn_config"].apply(lambda conf: conf is not None and conf.no_op) + is_attention_no_op = subblocks_df["attention_config"].apply( + lambda conf: conf is not None and conf.no_op + ) + is_duplicated = subblocks_df.duplicated(subset=UNIQUE_SUBBLOCK_IDENTIFIER, keep="first") + is_dup_of_decomp_no_op = is_duplicated & is_decomp & (is_ffn_no_op | is_attention_no_op) + subblocks_df = subblocks_df[~is_dup_of_decomp_no_op] + return subblocks_df + + +def _construct_subblock_rows_from_current_checkpoint( + checkpoint_dir: Path, subblocks_to_extract: list[str] +) -> list[dict[str, Any]]: + subblock_rows_from_current_checkpoint = [] + model_config = load_model_config(checkpoint_dir) + for block_idx, block_config in enumerate(model_config.block_configs): + for subblock_to_extract in subblocks_to_extract: + subblock_row = _init_empty_subblock_row(block_idx) + + if subblock_to_extract == "block": + subblock_row["block_config"] = block_config + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "ffn": + subblock_row["ffn_config"] = block_config.ffn + subblock_row["ffn_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.ffn.no_op else None + ) + elif subblock_to_extract == "attention": + subblock_row["attention_config"] = block_config.attention + subblock_row["attention_checkpoint_dir"] = ( + str(checkpoint_dir) if not block_config.attention.no_op else None + ) + else: + raise ValueError() + + 
subblock_rows_from_current_checkpoint.append(subblock_row) + return subblock_rows_from_current_checkpoint + + +def _add_no_op_subblock_rows( + subblocks_df: pd.DataFrame, + add_ffn_no_op: bool, + add_attention_no_op: bool, +) -> pd.DataFrame: + n_layer = subblocks_df["block_idx"].max() + 1 + + no_op_subblocks = [] + if add_ffn_no_op: + no_op_subblocks.append("ffn") + if add_attention_no_op: + no_op_subblocks.append("attention") + + additional_no_op_rows = [] + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, subblock_cls = _get_rows_with_no_op_subblock( + subblocks_df, no_op_subblock + ) + existing_no_op_indices = rows_with_no_op_subblock["block_idx"].values + missing_no_op_indices = list(set(range(n_layer)) - set(existing_no_op_indices)) + for block_idx in missing_no_op_indices: + no_op_subblock_row = { + **_init_empty_subblock_row(block_idx), + f"{no_op_subblock}_config": subblock_cls(no_op=True), + } + additional_no_op_rows.append(no_op_subblock_row) + + subblocks_df = pd.concat([subblocks_df, pd.DataFrame(additional_no_op_rows)]) + + for no_op_subblock in no_op_subblocks: + rows_with_no_op_subblock, _ = _get_rows_with_no_op_subblock(subblocks_df, no_op_subblock) + assert len(rows_with_no_op_subblock) == n_layer, ( + f"Got {len(rows_with_no_op_subblock)} rows with {no_op_subblock}=no_op, but we have {n_layer} layers" + ) + return subblocks_df + + +def _get_rows_with_no_op_subblock( + subblocks_df: pd.DataFrame, no_op_subblock: str +) -> tuple[pd.DataFrame, Type[AttentionConfig] | Type[FFNConfig]]: + other_subblock = "ffn" if no_op_subblock == "attention" else "attention" + subblock_cls = AttentionConfig if no_op_subblock == "attention" else FFNConfig + no_op_subblock_config = subblock_cls(no_op=True) + rows_with_no_op_subblock = subblocks_df[ + (subblocks_df[f"{no_op_subblock}_config"] == no_op_subblock_config) + & subblocks_df[f"{other_subblock}_config"].isna() + ] + return rows_with_no_op_subblock, subblock_cls + + +def 
_get_last_checkpoint_from_each_experiment(master_puzzle_dir: Path | str) -> set[Path]: + master_puzzle_dir = Path(master_puzzle_dir) + master_checkpoints_dir = master_puzzle_dir / CHECKPOINTS_DIR_NAME + subdirs_of_master_checkpoints_dir = [ + p.resolve() for p in master_checkpoints_dir.iterdir() if p.is_dir() + ] + checkpoint_dirs = [ + p.parent + for subdir in subdirs_of_master_checkpoints_dir + for p in subdir.rglob("config.json") + ] + + for checkpoint_dir in checkpoint_dirs: + if checkpoint_dir == master_checkpoints_dir: + raise ValueError( + f"We need at least 1 hierarchy level under the '{CHECKPOINTS_DIR_NAME}' dir. " + "Name your checkpoints, preferably with meaningful names. " + "If you are Ido Galil, tell Tomer that you got this exception ;) " + ) + + # Filter out non-DeciLM checkpoints (e.g., unconverted Llama checkpoints) + valid_checkpoint_dirs = [cp for cp in checkpoint_dirs if is_valid_decilm_checkpoint(cp)] + + experiment_dirs = [ + p if (p in subdirs_of_master_checkpoints_dir) else p.parent for p in valid_checkpoint_dirs + ] + + deduped_checkpoint_dirs = set( + pd.DataFrame({"checkpoint_dir": valid_checkpoint_dirs, "experiment_dir": experiment_dirs}) + .sort_values("checkpoint_dir") + .drop_duplicates(subset="experiment_dir", keep="last")["checkpoint_dir"] + .tolist() + ) + return deduped_checkpoint_dirs + + +def _infer_subblocks_to_extract( + checkpoint_dir: Path, + checkpoints_to_split: list[Path], +) -> list[str]: + if (checkpoint_dir / "replacement_library.json").exists(): + return [] + bypass_config_path = checkpoint_dir / "bypass_config.json" + if (checkpoint_dir in checkpoints_to_split) or (not bypass_config_path.exists()): + subblocks_to_extract = ["block", "attention", "ffn"] + else: + bypass_config = json.loads(bypass_config_path.read_text()) + keys_to_learn = bypass_config.get("keys_to_learn", "entire_block") + if keys_to_learn == "entire_block": + subblocks_to_extract = ["block"] + elif "mlp" in keys_to_learn and "attn" not in 
keys_to_learn: + subblocks_to_extract = ["ffn"] + elif "attn" in keys_to_learn and "mlp" not in keys_to_learn: + subblocks_to_extract = ["attention"] + else: + raise ValueError(f"Unrecognized {keys_to_learn=}") + return subblocks_to_extract + + +def _init_empty_subblock_row(block_idx: int) -> dict[str, Any]: + return { + "attention_checkpoint_dir": None, + "ffn_checkpoint_dir": None, + "block_config": None, + "attention_config": None, + "ffn_config": None, + "block_idx": block_idx, + "block_repr": None, + "attention_repr": None, + "ffn_repr": None, + } + + +def _build_layer_replacements( + block_library_df: pd.DataFrame, + master_puzzle_dir: Path, + teacher_checkpoint_dir: Path, +) -> list[dict]: + layer_replacements_from_blocks = _build_layer_replacements_from_block_library(block_library_df) + layer_replacements_from_checkpoints = _gather_layer_replacements_from_checkpoints( + master_puzzle_dir + ) + layer_replacements = layer_replacements_from_blocks + layer_replacements_from_checkpoints + layer_replacements = _filter_duplicate_teacher_replacements( + layer_replacements, teacher_checkpoint_dir + ) + return layer_replacements + + +def _build_layer_replacements_from_block_library(block_library_df: pd.DataFrame) -> list[dict]: + layer_replacements = [] + for _, row in block_library_df.iterrows(): + block_idx = row["block_idx"] + block_config = row["block_config"] + weight_paths = [] + for subblock_name in ["attention", "ffn"]: + checkpoint_dir = row[f"{subblock_name}_checkpoint_dir"] + if checkpoint_dir is not None: + subblock_path = ( + Path(checkpoint_dir) + / SAFETENSORS_SUBBLOCKS_DIR_NAME + / f"block_{block_idx}_{subblock_name}.safetensors" + ) + weight_paths.append(subblock_path) + weight_paths = sorted(set(weight_paths)) + layer_replacement = { + "parent_layer_indices": [block_idx], + "child_block_configs": [block_config], + "weight_paths": weight_paths, + } + layer_replacements.append(layer_replacement) + return layer_replacements + + +def 
def _gather_layer_replacements_from_checkpoints(master_puzzle_dir: str | Path) -> list[dict]:
    # Collect pre-built replacement libraries shipped inside individual checkpoints.
    gathered_layer_replacements = []
    checkpoint_dirs = _get_last_checkpoint_from_each_experiment(master_puzzle_dir)
    for checkpoint_dir in checkpoint_dirs:
        if (layer_replacements_path := checkpoint_dir / "replacement_library.json").exists():
            layer_replacements = json.loads(layer_replacements_path.read_text())
            # Rehydrate JSON dicts into BlockConfig objects and normalize weight paths.
            for layer_replacement in layer_replacements:
                layer_replacement["child_block_configs"] = [
                    BlockConfig(**block_config_dict)
                    for block_config_dict in layer_replacement["child_block_configs"]
                ]
                layer_replacement["weight_paths"] = sorted(
                    set(Path(p) for p in layer_replacement["weight_paths"])
                )
            gathered_layer_replacements.extend(layer_replacements)
    return gathered_layer_replacements


def _filter_duplicate_teacher_replacements(
    layer_replacements: list[dict],
    teacher_checkpoint_dir: Path,
) -> list[dict]:
    # Keep a replacement if it IS the teacher's own, or if it differs from the teacher.
    # This drops student replacements that merely re-implement the teacher layer.
    teacher_model_config = load_model_config(teacher_checkpoint_dir)
    filtered_layer_replacements = []
    for layer_replacement in layer_replacements:
        if replacement_is_teacher(
            layer_replacement, teacher_model_config, teacher_checkpoint_dir
        ) or not is_replacement_identical_to_teacher(layer_replacement, teacher_model_config):
            filtered_layer_replacements.append(layer_replacement)
    return filtered_layer_replacements


def _build_single_sequence_replacement_solutions(
    layer_replacements: list[dict],
    teacher_checkpoint_dir: Path,
) -> list[dict]:
    # For each student replacement, build a full-model "solution" where that single
    # replacement is applied and every other layer uses the teacher's replacement.
    teacher_model_config = load_model_config(teacher_checkpoint_dir)
    n_layer = teacher_model_config.num_hidden_layers

    # Partition into teacher replacements (indexed by layer) and student replacements.
    teacher_replacements = dict()
    student_replacements = []
    for layer_replacement in layer_replacements:
        if replacement_is_teacher(layer_replacement, teacher_model_config, teacher_checkpoint_dir):
            block_idx = layer_replacement["parent_layer_indices"][0]
            teacher_replacements[block_idx] = layer_replacement
        else:
            student_replacements.append(layer_replacement)

    # Every layer must have a teacher replacement, otherwise a solution can't be completed.
    teacher_indices_represented_in_replacements = sorted(teacher_replacements.keys())
    assert teacher_indices_represented_in_replacements == list(range(n_layer)), (
        f"{n_layer=}, {teacher_indices_represented_in_replacements=}"
    )

    student_replacements = sort_replacements(student_replacements)

    solutions = []
    for layer_replacement in student_replacements:
        # Fill all layers not covered by this replacement with the teacher's layers.
        block_indices_not_represented_in_replacement = sorted(
            set(range(n_layer)) - set(layer_replacement["parent_layer_indices"])
        )
        chosen_replacements = sort_replacements(
            [layer_replacement]
            + [
                teacher_replacements[block_idx]
                for block_idx in block_indices_not_represented_in_replacement
            ]
        )

        block_configs = [
            block_config
            for replacement in chosen_replacements
            for block_config in replacement["child_block_configs"]
        ]

        solutions.append(
            {
                "single_sequence_replacement": layer_replacement,
                "chosen_replacements": chosen_replacements,
                "block_configs": block_configs,
            }
        )

    return solutions


@hydra.main("", version_base="1.3")
def main(cfg: DictConfig) -> None:
    # Hydra entry point: instantiate the config, print it, and build the library.
    cfg = hydra.utils.instantiate(cfg)
    mprint(format_global_config(cfg))
    launch_build_replacement_library(cfg)


if __name__ == "__main__":
    register_hydra_resolvers()
    main()
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Replacement library for efficiently loading and managing layer-replaced DeciLM models.
- Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations
"""
# mypy: ignore-errors

import json
import re
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from immutabledict import immutabledict
from lru import LRU
from safetensors.torch import load_file as safe_load_file
from torch import nn

from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig
from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import (
    DeciLMDecoderLayer,
    DeciLMForCausalLM,
    DeciLMMultiDecoderLayer,
    DeciLMRMSNorm,
    LMHead,
)
from modelopt.torch._compress.replacement_library.replacement_utils import (
    extract_block_configs_and_locations,
    parse_layer_replacement,
    sort_replacements,
    weights_path_to_checkpoint_dir,
)
from modelopt.torch._compress.tools.checkpoint_utils import (
    PTH_SUBBLOCKS_DIR_NAME,
    SAFETENSORS_SUBBLOCKS_DIR_NAME,
    infer_weights_dtype,
    init_empty_module,
    init_module_with_state_dict,
    load_model_config,
)
from modelopt.torch._compress.tools.sharded_checkpoint_utils import (
    create_dummy_model,
    is_in_safetensors_format,
    load_sharded_state_dict,
)


class ReplacementLibrary:
    """Loads mixed teacher/student DeciLM models from a replacement-library JSON file,
    caching loaded layer replacements in an LRU dict."""

    def __init__(
        self,
        replacement_library_path: str | Path,
        model_config_overrides: Optional[dict] = None,
    ):
        """
        Args:
            replacement_library_path: Path to replacement_library.json inside a puzzle dir.
            model_config_overrides: Optional overrides applied when loading model configs;
                stored as an immutabledict so it can be shared safely.
        """
        self.replacement_library = self._load_replacement_library(replacement_library_path)
        # Fail fast if any referenced checkpoint hasn't been split into subblock safetensors.
        self._ensure_all_checkpoints_are_split()
        self.model_config_overrides = (
            immutabledict(model_config_overrides) if (model_config_overrides is not None) else None
        )

        self._loaded_replacements: dict[str, nn.ModuleList] = LRU(
            size=256
        )  # least-recently-used dict: a dict of fixed size that evicts old items

        self._dtype = None

        # Teacher checkpoint is assumed at '{puzzle_dir}/ckpts/teacher' next to the library file.
        self.teacher_dir = Path(replacement_library_path).parent / "ckpts" / "teacher"
        # Lazily-initialized caches for shared model pieces.
        self._model_config = None
        self._embedding = None
        self._ln_f = None
        self._lm_head = None
        self._arbitrary_checkpoint_dir = None

    @staticmethod
    def _load_replacement_library(replacement_library_path: str | Path) -> list[dict]:
        # Parse the raw JSON entries into normalized replacement dicts.
        replacement_library = json.loads(Path(replacement_library_path).read_text())
        replacement_library = [
            parse_layer_replacement(layer_replacement) for layer_replacement in replacement_library
        ]
        return replacement_library

    def _ensure_all_checkpoints_are_split(self) -> None:
        # Every checkpoint referenced by the library must contain the subblock
        # safetensors directory produced by the splitting step.
        checkpoint_dirs = self._get_all_checkpoint_dirs()
        unsplit_checkpoints = []
        for checkpoint_dir in checkpoint_dirs:
            if not (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists():
                unsplit_checkpoints.append(checkpoint_dir)
        assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}"

    @property
    def dtype(self) -> torch.dtype:
        # Inferred once from the final layer norm weights.
        if self._dtype is None:
            ln_f = self.get_ln_f()
            self._dtype = ln_f.weight.dtype
        return self._dtype

    @property
    def n_layer(self) -> int:
        return self.model_config.get_num_hidden_layers()

    @property
    def model_config(self) -> DeciLMConfig:
        # Loaded lazily from an arbitrary checkpoint, with optional overrides applied.
        if self._model_config is None:
            self._model_config = load_model_config(
                self.get_arbitrary_checkpoint_dir(), self.model_config_overrides
            )
        return self._model_config

    def create_model_config(self, layer_replacements: list[dict]):
        # Build a model config whose block configs come from the given replacements.
        block_configs, _ = extract_block_configs_and_locations(layer_replacements)
        model_config = self.model_config.set_block_configs(block_configs)
        return model_config
load_model( + self, + layer_replacements: list[dict], + world_size: int, + global_rank: int, + ) -> DeciLMForCausalLM: + block_configs, block_locations = extract_block_configs_and_locations(layer_replacements) + model_config = self.model_config.set_block_configs(block_configs) + + owned_block_indexes = _get_owned_block_indexes( + model_config.get_num_hidden_layers(), world_size, global_rank + ) + model = create_dummy_model(model_config, self.dtype) + + is_first_shard = 0 in owned_block_indexes + if is_first_shard and not isinstance(model.model.get_input_embeddings(), nn.Embedding): + model.set_input_embeddings(self.get_embedding()) + + is_last_shard = model_config.get_num_hidden_layers() - 1 in owned_block_indexes + if is_last_shard and not isinstance(model.model.get_output_embeddings(), nn.Linear): + model.model.set_final_layer_norm(self.get_ln_f()) + model.set_output_embeddings(self.get_lm_head()) + + active_blocks = [] + for block_idx in owned_block_indexes: + layer_replacement, block_idx_in_replacement = block_locations[block_idx] + block = self.get_block(layer_replacement, block_idx_in_replacement) + model.model.layers[block_idx] = block + active_blocks.append(block) + + self._move_inactive_blocks_to_cpu(active_blocks) + return model + + def load_checkpoint( + self, + checkpoint_dir: str | Path, + world_size: int, + global_rank: int, + ) -> DeciLMForCausalLM: + checkpoint_dir = Path(checkpoint_dir).resolve() + layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) + model = self.load_model(layer_replacements, world_size, global_rank) + return model + + def _locate_replacements_of_entire_checkpoint(self, checkpoint_dir: str | Path) -> list[dict]: + weight_paths_located = [] + layer_replacements = [] + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + weight_paths = [Path(p).absolute().resolve() for p in weight_paths] + layer_replacement["weight_paths"] = weight_paths + if 
len(weight_paths) > 0 and all( + p.is_relative_to(checkpoint_dir) for p in weight_paths + ): + layer_replacements.append(layer_replacement) + weight_paths_located.extend(weight_paths) + + all_block_weight_paths = [ + p + for p in list((checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).iterdir()) + if p.name not in ("embeddings.safetensors", "lm_head.safetensors") + ] + missing_paths = set(all_block_weight_paths) - set(weight_paths_located) + assert len(missing_paths) == 0, ( + f"Couldn't locate replacements for the entire checkpoint {checkpoint_dir}, missing weights: {missing_paths}" + ) + + deduped_layer_replacements = [] + for weights_path in all_block_weight_paths: + replacements_with_path = [ + rep for rep in layer_replacements if weights_path in rep["weight_paths"] + ] + largest_replacement_with_path = max( + replacements_with_path, key=lambda rep: len(rep["weight_paths"]) + ) + if largest_replacement_with_path not in deduped_layer_replacements: + deduped_layer_replacements.append(largest_replacement_with_path) + + deduped_layer_replacements = sort_replacements(deduped_layer_replacements) + return deduped_layer_replacements + + def get_block( + self, layer_replacement: dict, block_idx_in_replacement: int + ) -> DeciLMDecoderLayer | DeciLMMultiDecoderLayer: + if str(layer_replacement) not in self._loaded_replacements.keys(): + self._loaded_replacements[str(layer_replacement)] = self._load_layer_replacement( + layer_replacement + ) + module_list = self._loaded_replacements[str(layer_replacement)] + block = module_list[block_idx_in_replacement] + return block + + def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: + state_dict = dict() + for weights_path in layer_replacement["weight_paths"]: + if weights_path.suffix == ".safetensors": + curr_state_dict = safe_load_file(weights_path) + elif weights_path.suffix == ".pth": + curr_state_dict = torch.load(weights_path, weights_only=True) + else: + raise ValueError(f"Unrecognized suffix of 
{weights_path=}") + for param_name in curr_state_dict.keys(): + assert param_name not in state_dict, ( + f"Duplicate entries for {param_name=} in {layer_replacement=}" + ) + state_dict.update(curr_state_dict) + + if len(state_dict) > 0: + block_indices = [ + int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) + for param_name in state_dict.keys() + ] + assert sorted(set(block_indices)) == list( + range(min(block_indices), max(block_indices) + 1) + ), ( + f"Block indices in loaded weight files must be consecutive, but found {sorted(set(block_indices))} in {layer_replacement=}" + ) + + min_block_idx = min(block_indices) + + state_dict = { + param_name.replace( + f"model.layers.{block_idx}.", f"{block_idx - min_block_idx}." + ): param_weight + for block_idx, (param_name, param_weight) in zip(block_indices, state_dict.items()) + } + + dtype = infer_weights_dtype(state_dict) + model_config = self.model_config.set_block_configs(layer_replacement["child_block_configs"]) + + module_list = nn.ModuleList( + [ + ( + init_empty_module(DeciLMDecoderLayer, dtype, model_config, layer_idx) + if (block_config.parallel_blocks is None) + else init_empty_module(DeciLMMultiDecoderLayer, dtype, model_config, layer_idx) + ) + for layer_idx, block_config in enumerate(layer_replacement["child_block_configs"]) + ] + ) + + module_list.load_state_dict(state_dict, strict=True) + return module_list + + def _move_inactive_blocks_to_cpu(self, active_blocks: list[nn.Module]) -> None: + for module_list in self._loaded_replacements.values(): + for module in module_list: + if module not in active_blocks: + module.to("cpu") + + def get_embedding(self) -> nn.Embedding: + if self._embedding is None: + state_dict = { + "weight": self._get_arbitrary_non_block_param( + self.model_config.get_embedding_layer_name() + ".weight" + ) + } + self._embedding = init_module_with_state_dict( + state_dict, + nn.Embedding, + num_embeddings=self.model_config.vocab_size, + 
embedding_dim=self.model_config.hidden_size, + ) + return self._embedding + + def get_ln_f(self) -> DeciLMRMSNorm: + if self._ln_f is None: + state_dict = { + "weight": self._get_arbitrary_non_block_param( + self.model_config.get_final_layer_norm_layer_name() + ".weight" + ) + } + self._ln_f = init_module_with_state_dict( + state_dict, + DeciLMRMSNorm, + hidden_size=self.model_config.hidden_size, + eps=self.model_config.rms_norm_eps, + ) + return self._ln_f + + def get_lm_head(self) -> nn.Linear: + if self._lm_head is None: + state_dict = { + "weight": self._get_arbitrary_non_block_param( + self.model_config.get_lm_head_layer_name() + ".weight" + ) + } + self._lm_head = init_module_with_state_dict( + state_dict, + LMHead, + out_features=self.model_config.vocab_size, + in_features=self.model_config.hidden_size, + bias=False, + ) + return self._lm_head + + def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: + checkpoint_dir = self.get_arbitrary_checkpoint_dir() + if ( + is_in_safetensors_format(checkpoint_dir) + or (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists() + ): + partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) + return partial_state_dict[param_name] + + non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" + assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) + non_block_state_dict = torch.load(non_block_pth_path) + return non_block_state_dict[param_name] + + def get_arbitrary_checkpoint_dir(self) -> Path: + if self._arbitrary_checkpoint_dir is None: + self._arbitrary_checkpoint_dir = self._get_arbitrary_checkpoint_dir() + return self._arbitrary_checkpoint_dir + + def get_teacher_dir(self) -> Path: + return self.teacher_dir + + def get_teacher_lm_head_path(self) -> Path: + return self.get_teacher_dir() / SAFETENSORS_SUBBLOCKS_DIR_NAME / "lm_head.safetensors" + + def get_teacher_embedding_path(self) -> Path: + return self.get_teacher_dir() / 
SAFETENSORS_SUBBLOCKS_DIR_NAME / "embeddings.safetensors" + + def _get_arbitrary_checkpoint_dir(self) -> Path: + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + if len(weight_paths) > 0: + return weights_path_to_checkpoint_dir(weight_paths[0]) + + def _get_all_checkpoint_dirs(self) -> list[Path]: + checkpoint_dirs = set() + for layer_replacement in self.replacement_library: + weight_paths = layer_replacement["weight_paths"] + for weights_path in weight_paths: + checkpoint_dir = weights_path_to_checkpoint_dir(weights_path) + checkpoint_dirs.add(checkpoint_dir) + return list(checkpoint_dirs) + + +def _error_message_ensure_split(checkpoint_dir: Path) -> str: + return ( + f"Encountered unsplit checkpoint dir '{checkpoint_dir}', " + f"please call `ensure_all_checkpoints_are_split`" + ) + + +def _get_owned_block_indexes(n_layer: int, world_size: int, global_rank: int) -> list[int]: + last_process_blocks = np.array([n_layer - 1]) # less params in last gpu, leave room for logits + + if world_size == 1: + # Only one process: assign everything (including the "last process" block) to rank 0 + owned_block_indexes_per_process = [ + np.concatenate([np.arange(n_layer - 1), last_process_blocks]) + ] + else: + # Multiple processes: split n_layer-1 blocks, reserve the last for "last process" + owned_block_indexes_per_process = np.array_split(range(n_layer - 1), world_size - 1) + owned_block_indexes_per_process.append(last_process_blocks) + + owned_block_indexes = owned_block_indexes_per_process[global_rank].tolist() + return owned_block_indexes diff --git a/modelopt/torch/_compress/replacement_library/replacement_utils.py b/modelopt/torch/_compress/replacement_library/replacement_utils.py new file mode 100644 index 0000000000..21ae411752 --- /dev/null +++ b/modelopt/torch/_compress/replacement_library/replacement_utils.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module provides helper functions for parsing, sorting, and analyzing layer replacement +configurations used in the replacement library for model compression. +""" + +# mypy: ignore-errors +import json +from copy import deepcopy +from pathlib import Path + +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import BlockConfig +from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig + + +def parse_layer_replacement(layer_replacement: dict | str) -> dict: + if isinstance(layer_replacement, str): + layer_replacement = json.loads(layer_replacement) + else: + layer_replacement = deepcopy(layer_replacement) + + if "layer_replacement" in layer_replacement: # happens in puzzle solutions + layer_replacement = layer_replacement["layer_replacement"] + + layer_replacement["child_block_configs"] = [ + BlockConfig(**block_config) if isinstance(block_config, dict) else block_config + for block_config in layer_replacement["child_block_configs"] + ] + layer_replacement["weight_paths"] = [Path(p) for p in layer_replacement["weight_paths"]] + return layer_replacement + + +def sort_replacements(layer_replacements: list[dict]) -> list[dict]: + return sorted(layer_replacements, key=lambda replacement: replacement["parent_layer_indices"]) + + +def extract_block_configs_and_locations( + layer_replacements: list[dict], +) -> 
tuple[list[BlockConfig], list[tuple[dict, int]]]: + layer_replacements = sort_replacements(layer_replacements) + block_configs = [] + block_locations = [] + for layer_replacement in layer_replacements: + child_block_configs = layer_replacement["child_block_configs"] + if not isinstance(child_block_configs, list | tuple): + child_block_configs = [child_block_configs] + for block_idx_in_replacement, block_config in enumerate(child_block_configs): + block_configs.append(block_config) + block_locations.append((layer_replacement, block_idx_in_replacement)) + return block_configs, block_locations + + +def weights_path_to_checkpoint_dir(weights_path: Path) -> Path: + checkpoint_dir: Path = weights_path + while checkpoint_dir != Path("/"): + if (checkpoint_dir / "config.json").exists(): + return checkpoint_dir + checkpoint_dir = checkpoint_dir.parent + raise FileNotFoundError(f"Couldn't find checkpoint dir for weights path {weights_path}") + + +def replacement_is_teacher( + layer_replacement: dict, + teacher_model_config: DeciLMConfig, + teacher_checkpoint_dir: Path, +) -> bool: + paths_all_teacher = all( + p.is_relative_to(teacher_checkpoint_dir) for p in layer_replacement["weight_paths"] + ) + return paths_all_teacher and is_replacement_identical_to_teacher( + layer_replacement, teacher_model_config + ) + + +def is_replacement_identical_to_teacher( + layer_replacement: dict, + teacher_model_config: DeciLMConfig, +) -> bool: + if len(layer_replacement["parent_layer_indices"]) == 1: + block_idx = layer_replacement["parent_layer_indices"][0] + teacher_block_config = teacher_model_config.block_configs[block_idx] + if len(child_block_configs := layer_replacement["child_block_configs"]) == 1: + replacement_block_config: BlockConfig = child_block_configs[0] + if replacement_block_config == teacher_block_config: + return True + else: + parallel_blocks = getattr(replacement_block_config, "parallel_blocks", None) + if ( + parallel_blocks is not None + and len(parallel_blocks) == 1 
+ and parallel_blocks[0].attention == teacher_block_config.attention + and parallel_blocks[0].ffn == teacher_block_config.ffn + ): + return True + return False + + +def split_replacements_to_teacher_and_student( + replacements: list[dict], + teacher_model_config: DeciLMConfig, + teacher_checkpoint_dir: Path, +) -> tuple[list[dict], list[dict]]: + teacher_replacements, student_replacements = [], [] + for replacement in replacements: + if replacement_is_teacher(replacement, teacher_model_config, teacher_checkpoint_dir): + teacher_replacements.append(replacement) + else: + student_replacements.append(replacement) + return teacher_replacements, student_replacements diff --git a/modelopt/torch/_compress/utils/utils.py b/modelopt/torch/_compress/utils/utils.py index 6e2ba9339a..74329bcd0a 100644 --- a/modelopt/torch/_compress/utils/utils.py +++ b/modelopt/torch/_compress/utils/utils.py @@ -13,8 +13,106 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses +from typing import Any + import torch +from modelopt.torch._compress.decilm.deci_lm_hf_code.block_config import ( + AttentionConfig, + BlockConfig, + FFNConfig, +) + + +def block_config_to_str(block_config: BlockConfig | dict[str, Any] | None) -> str | None: + """ + Convert a BlockConfig to a human-readable string representation. + + TODO: Consider a better place for this function. + Args: + block_config: BlockConfig dataclass or dict containing attention and ffn configs. + + Returns: + Formatted string with attention and FFN information, or None if input is None. 
+ """ + if block_config is None: + return None + rep = "" + if dataclasses.is_dataclass(block_config): + block_config = dataclasses.asdict(block_config) + for subblock_name in ["attention", "ffn"]: + subblock_config = block_config[subblock_name] + rep += subblock_config_to_str(subblock_config, subblock_name) + return rep + + +def subblock_config_to_str( + subblock_config: FFNConfig | AttentionConfig | dict[str, Any] | None, + subblock_name: None | str = None, +) -> str | None: + """Convert a subblock config (FFN, Attention, Mamba, or MoE) to string. + + TODO: Consider a better place for this function. + Args: + subblock_config: FFNConfig, AttentionConfig dataclass or dict. + subblock_name: Name of subblock ('ffn', 'attention', 'mamba', 'moe'). + Auto-detected if subblock_config is a dataclass. + + Returns: + Formatted string showing subblock type and key parameters (e.g., intermediate_size, + n_heads_in_group), or None if input is None. + """ + if subblock_config is None: + return None + subblock_name = ( + "ffn" + if isinstance(subblock_config, FFNConfig) + else "mamba" + if isinstance(subblock_config, AttentionConfig) and subblock_config.is_mamba + else "attention" + if isinstance(subblock_config, AttentionConfig) + else subblock_name + ) + assert subblock_name is not None, "Must provide subblock_name if subblock_config is a dict." 
+ + if dataclasses.is_dataclass(subblock_config): + subblock_config = dataclasses.asdict(subblock_config) + + if subblock_name == "attention" and subblock_config.get("mamba") is not None: + subblock_name = "mamba" + + if subblock_name == "ffn" and subblock_config.get("moe") is not None: + subblock_name = "moe" + + rep = f" {subblock_name}" + if subblock_config.get("no_op"): + rep += " no_op".ljust(8) + elif subblock_config.get("replace_with_linear"): + rep += " linear".ljust(8) + elif subblock_name == "ffn": + intermediate_size = subblock_config["intermediate_size"] + rep += f" intermediate_{intermediate_size}".ljust(8) + elif subblock_name == "attention": + n_heads_in_group = subblock_config["n_heads_in_group"] + rep += f" gqa_{n_heads_in_group}".ljust(8) + elif subblock_name == "mamba": + mamba_num_heads = subblock_config["mamba"]["num_heads"] + mamba_head_dim = subblock_config["mamba"]["head_dim"] + rep += f" num_heads_{mamba_num_heads} head_dim_{mamba_head_dim}".ljust(8) + elif subblock_name == "moe": + moe_num_local_experts = subblock_config["moe"]["num_local_experts"] + moe_expert_intermediate_dim = subblock_config["moe"]["expert_intermediate_dim"] + shared_expert_intermediate_dim = subblock_config["moe"]["shared_expert_intermediate_dim"] + num_experts_per_tok = subblock_config["moe"]["num_experts_per_tok"] + rep += f" num_experts_{moe_num_local_experts} expert_intermediate_dim_{moe_expert_intermediate_dim} shared_expert_intermediate_dim_{shared_expert_intermediate_dim} num_experts_per_tok_{num_experts_per_tok}".ljust( + 8 + ) + else: + raise ValueError(f"subblock_config_to_str: unrecognized subblock_name: {subblock_name}.") + + return rep + class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): def __init__(self, device=None, dtype=None): diff --git a/setup.py b/setup.py index ab70cdf68a..3eb41967d1 100644 --- a/setup.py +++ b/setup.py @@ -108,6 +108,8 @@ "wandb~=0.17.5", "lru-dict", "typeguard", + "pandas", + "immutabledict", ], }