diff --git a/modelopt/torch/_compress/README.md b/modelopt/torch/_compress/README.md new file mode 100644 index 0000000000..4c6da80e54 --- /dev/null +++ b/modelopt/torch/_compress/README.md @@ -0,0 +1,3 @@ +Experimental model compression algorithm based on a Local Neural Architecture Search. +Based on the Puzzle paper: +PoC for Llama 3.1 model. diff --git a/modelopt/torch/_compress/compress.py b/modelopt/torch/_compress/compress.py new file mode 100644 index 0000000000..265fd5eeb2 --- /dev/null +++ b/modelopt/torch/_compress/compress.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + +This module provides the main compression function for a model +using MIP-based NAS search algorithm. + +""" + +import build_library_and_stats +import mip_and_realize_models +import pruning_ckpts +import score_pruning_activations +import scoring +from omegaconf import DictConfig +from puzzle_tools.runtime import IRuntime + +# TODO Move initialize_hydra_config_for_dir from tests to main +from tests.utils.test_utils import initialize_hydra_config_for_dir + + +def compress( + hydra_config_dir: str, hydra_config: str, puzzle_dir: str, dataset_path: str, runtime: IRuntime +) -> DictConfig: + """Compress a puzzletron model using the MIP-based NAS search algorithm. 
+ + Args: + hydra_config_dir (str): path to a hydra_config_dir that defines the search space + hydra_config (str): the corresponding hydra config file + puzzle_dir (str): directory with a puzzletron model to compress + dataset_path (str): dataset used for scoring and distillation + runtime: distributed runtime to use to run the compression steps, e.g., + NativeDdpRuntime(dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)) + + Returns: + Hydra config object after compressing the model. + The same hydra configuration object is used across all compression steps. + @TODO: Investigate if this config object is immutable across steps and clarify + """ + # Step 0: Load puzzletron hydra config + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config, + overrides=[ + f"puzzle_dir={puzzle_dir}", + f"dataset_path={dataset_path}", + ], + ) + + # Step 1: score_pruning_activations (distributed processing) + score_pruning_activations.launch_score_activations(hydra_cfg, runtime) + + # Step 2: pruning_ckpts (single process) + if runtime.global_rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + runtime.wait_for_everyone() + + # Step 4: build_library_and_stats (single process) + if runtime.global_rank == 0: + build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + runtime.wait_for_everyone() + + # Step 5: calc_one_block_scores (distributed processing) + scoring.launch_scoring(hydra_cfg, runtime) + + # Step 6: mip_and_realize_models (distributed processing) + mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg, runtime) + + return hydra_cfg diff --git a/modelopt/torch/_compress/runtime.py b/modelopt/torch/_compress/runtime.py new file mode 100644 index 0000000000..46f561a5d9 --- /dev/null +++ b/modelopt/torch/_compress/runtime.py @@ -0,0 +1,556 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes for torch distributed runtime management""" + +import os +import random +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import AbstractContextManager, suppress +from datetime import timedelta +from pathlib import Path +from typing import Literal, TypeVar, cast + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from torch.utils.data import DataLoader +from tqdm import tqdm +from typing_extensions import override + +PrepareModelsT = TypeVar("PrepareModelsT", bound=Sequence[nn.Module]) +PrepareDataLoaderT = TypeVar("PrepareDataLoaderT", bound=DataLoader) +CompileT = TypeVar("CompileT", bound=nn.Module) +Filter = ( + Literal["main_process", "last", "local_main_process", "local_last", "all"] + | list[int] + | set[int] + | Callable[[int], bool] +) + + +class IRuntime(ABC): + @abstractmethod + def setup(self) -> None: ... + + @abstractmethod + def cleanup(self) -> None: ... + + @abstractmethod + def autocast(self) -> AbstractContextManager: ... + + @abstractmethod + def wait_for_everyone(self) -> None: ... + + @abstractmethod + def set_seed(self, seed: int, device_specific: bool = False) -> int: ... + + @abstractmethod + def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: ... 
+ + @abstractmethod + def prepare_train_dataloader( + self, train_dataloader: PrepareDataLoaderT + ) -> PrepareDataLoaderT: ... + + @abstractmethod + def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: ... + + @abstractmethod + def compile(self, model: CompileT) -> CompileT: ... + + @abstractmethod + def backward(self, loss: torch.Tensor) -> None: ... + + @abstractmethod + def clip_grad_norm_( + self, + parameters: Iterable[torch.Tensor] | torch.Tensor, + max_norm: float, + norm_type: float = 2, + ) -> torch.Tensor: ... + + @abstractmethod + def clip_grad_value_( + self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float + ) -> None: ... + + @abstractmethod + def save_state(self, path: str | Path) -> None: ... + + @abstractmethod + def load_state(self, path: str | Path) -> None: ... + + @abstractmethod + def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: ... + + @property + @abstractmethod + def sync_gradients(self) -> bool: ... + + @property + @abstractmethod + def device(self) -> torch.device: ... + + @property + @abstractmethod + def is_main_process(self) -> bool: ... + + @property + @abstractmethod + def is_local_main_process(self) -> bool: ... + + @property + @abstractmethod + def is_last_process(self) -> bool: ... + + @property + @abstractmethod + def is_local_last_process(self) -> bool: ... + + @property + @abstractmethod + def local_rank(self) -> int: ... + + @property + @abstractmethod + def global_rank(self) -> int: ... + + @property + @abstractmethod + def local_world_size(self) -> int: ... + + @property + @abstractmethod + def world_size(self) -> int: ... + + @property + @abstractmethod + def dtype(self) -> torch.dtype: ... 
+ + def __enter__(self): + self.setup() + return self + + def __exit__(self, exc_type, exc_value, traceback): + # avoid barrier if execution errored + if exc_type is None: + self.cleanup() + + # if exc_type is not None: + # raise exc_value + # Handle exceptions if necessary + # pass + + # def __del__(self): + # torch.distributed.barrier() + # torch.distributed.destroy_process_group() + + def check_filter(self, filter_: Filter): + return ( + filter_ == "all" + or (filter_ == "main_process" and self.is_main_process) + or (filter_ == "local_main_process" and self.is_local_main_process) + or (filter_ == "last" and self.is_last_process) + or (filter_ == "local_last" and self.is_local_last_process) + or (isinstance(filter_, (list, set)) and self.global_rank in filter_) + or (callable(filter_) and filter_(self.global_rank)) + ) + + def print( + self, *args, filter_: Filter = "main_process", rank_prefix=False, flush=True, **kwargs + ) -> None: + if not self.check_filter(filter_): + return + + if rank_prefix: + print(f"[global_rank={self.global_rank}]", *args, flush=flush, **kwargs) + else: + print(*args, flush=flush, **kwargs) + + def process_print( + self, *args, filter_: Filter = "all", rank_prefix=True, flush=True, **kwargs + ) -> None: + if not self.check_filter(filter_): + return + + if rank_prefix: + prefix = f"[global_rank={self.global_rank}]" + if len(args) == 1: # avoid out-of-order printing if possible + out = f"{prefix} {args[0]}" + args = (out,) + else: + args = (prefix, *args) + print(*args, flush=flush, **kwargs) + else: + print(*args, flush=flush, **kwargs) + + +class NativeDdpRuntime(IRuntime): + def __init__( + self, + dtype: torch.dtype = torch.float, + torch_distributed_timeout: timedelta | None = None, + ): + self._master_addr = os.environ["MASTER_ADDR"] + self._master_port = int(os.environ["MASTER_PORT"]) + self._local_rank = int(os.environ["LOCAL_RANK"]) + self._global_rank = int(os.environ["RANK"]) + self._local_world_size = 
int(os.environ["LOCAL_WORLD_SIZE"]) + self._world_size = int(os.environ["WORLD_SIZE"]) + self._device = torch.device(self.local_rank) + self._dtype = dtype + self._torch_distributed_timeout = torch_distributed_timeout + + @override + def setup(self): + torch.cuda.set_device(self._device) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + "cpu:gloo,cuda:nccl", timeout=self._torch_distributed_timeout + ) + input_tensors = [ + torch.tensor([0], dtype=torch.float32, device=self._device) + for _ in range(self.world_size) + ] + output_tensors = [ + torch.tensor([0], dtype=torch.float32, device=self._device) + for _ in range(self.world_size) + ] + torch.distributed.all_to_all(input_tensors, output_tensors) + + @override + def cleanup(self): + with suppress(Exception): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + @override + def autocast(self) -> AbstractContextManager: + result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) + return result + + @override + def wait_for_everyone(self): + torch.distributed.barrier() + + @override + def set_seed(self, seed: int, device_specific: bool = False) -> int: + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): + The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. 
+ """ + if device_specific: + seed += self.global_rank + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + return seed + + @override + def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: + assert all(isinstance(x, nn.Module) for x in models) + new_models = [nn.parallel.DistributedDataParallel(m) for m in models] + new_models = cast("PrepareModelsT", new_models) + return new_models # type: ignore[return-value] + + @override + def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return train_dataloader + + @override + def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return val_dataloader + + @override + def compile(self, model: CompileT) -> CompileT: + result = torch.compile(model) + result = cast("CompileT", result) + return result + + @override + def backward(self, loss: torch.Tensor) -> None: + loss.backward() + + @override + def clip_grad_norm_( + self, + parameters: Iterable[torch.Tensor] | torch.Tensor, + max_norm: float, + norm_type: float = 2, + ) -> torch.Tensor: + result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) + return result + + @override + def clip_grad_value_( + self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float + ) -> None: + torch.nn.utils.clip_grad_value_(parameters, clip_value) + + @override + def save_state(self, path: str | Path) -> None: + pass + + @override + def load_state(self, path: str | Path) -> None: + pass + + @override + def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: + for _ in tqdm( + range(num_batches), desc=f"rank {self._global_rank}: skip_first_batches({num_batches=})" + ): + next(dataloader_iterator) + + @property + @override + def sync_gradients(self) -> bool: + return True + + @property + @override + def is_main_process(self) -> bool: + result = self.global_rank == 0 + 
return result + + @property + @override + def is_local_main_process(self) -> bool: + result = self.local_rank == 0 + return result + + @property + @override + def is_last_process(self) -> bool: + result = self.global_rank == self.world_size - 1 + return result + + @property + @override + def is_local_last_process(self) -> bool: + result = self.local_rank == self.local_world_size - 1 + return result + + @property + @override + def local_rank(self) -> int: + return self._local_rank + + @property + @override + def global_rank(self) -> int: + return self._global_rank + + @property + @override + def local_world_size(self) -> int: + return self._local_world_size + + @property + @override + def world_size(self) -> int: + return self._world_size + + @property + @override + def device(self) -> torch.device: + return self._device + + @property + @override + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def master_addr(self) -> str: + return self._master_addr + + @property + def master_port(self) -> int: + return self._master_port + + +class BaseRuntime(IRuntime): + def __init__(self, dtype: torch.dtype = torch.float): + self._device = torch.device(self.local_rank) + self._dtype = dtype + + @override + def setup(self): + torch.cuda.set_device(self._device) + + @override + def cleanup(self): ... + + @override + def autocast(self) -> AbstractContextManager: + result = torch.autocast(device_type="cuda", dtype=self._dtype, enabled=True) + return result + + @override + def wait_for_everyone(self): ... + + @override + def set_seed(self, seed: int, device_specific: bool = False) -> int: + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): + The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. 
+ """ + if device_specific: + seed += self.global_rank + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + return seed + + @override + def prepare_models(self, models: PrepareModelsT) -> PrepareModelsT: + assert all(isinstance(x, nn.Module) for x in models) + return models + + @override + def prepare_train_dataloader(self, train_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return train_dataloader + + @override + def prepare_val_dataloader(self, val_dataloader: PrepareDataLoaderT) -> PrepareDataLoaderT: + return val_dataloader + + @override + def compile(self, model: CompileT) -> CompileT: + result = torch.compile(model) + result = cast("CompileT", result) + return result + + @override + def backward(self, loss: torch.Tensor) -> None: + loss.backward() + + @override + def clip_grad_norm_( + self, + parameters: Iterable[torch.Tensor] | torch.Tensor, + max_norm: float, + norm_type: float = 2, + ) -> torch.Tensor: + result = torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) + return result + + @override + def clip_grad_value_( + self, parameters: Iterable[torch.Tensor] | torch.Tensor, clip_value: float + ) -> None: + torch.nn.utils.clip_grad_value_(parameters, clip_value) + + @override + def save_state(self, path: str | Path) -> None: + pass + + @override + def load_state(self, path: str | Path) -> None: + pass + + @override + def skip_first_batches(self, dataloader_iterator: Iterator, num_batches: int) -> None: + for _ in tqdm( + range(num_batches), desc=f"rank {self.global_rank}: skip_first_batches({num_batches=})" + ): + next(dataloader_iterator) + + @property + @override + def sync_gradients(self) -> bool: + return True + + @property + @override + def is_main_process(self) -> bool: + result = self.global_rank == 0 + return result + + @property + @override + def is_local_main_process(self) -> bool: + result = self.local_rank == 0 + return result + + @property + @override + 
def is_last_process(self) -> bool: + result = self.global_rank == self.world_size - 1 + return result + + @property + @override + def is_local_last_process(self) -> bool: + result = self.local_rank == self.local_world_size - 1 + return result + + @property + @override + def local_rank(self) -> int: + return 0 + + @property + @override + def global_rank(self) -> int: + return 0 + + @property + @override + def local_world_size(self) -> int: + return 1 + + @property + @override + def world_size(self) -> int: + return 1 + + @property + @override + def device(self) -> torch.device: + return self._device + + @property + @override + def dtype(self) -> torch.dtype: + return self._dtype + + @property + def master_addr(self) -> str | None: + return None + + @property + def master_port(self) -> int | None: + return None diff --git a/tests/_test_utils/torch_dist/dist_utils.py b/tests/_test_utils/torch_dist/dist_utils.py index c7407b0188..f7160cf288 100644 --- a/tests/_test_utils/torch_dist/dist_utils.py +++ b/tests/_test_utils/torch_dist/dist_utils.py @@ -34,6 +34,11 @@ def init_process(rank, size, job=None, backend="gloo", port=None): """Initialize the distributed environment.""" os.environ["MASTER_ADDR"] = "localhost" + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(size) + os.environ["LOCAL_WORLD_SIZE"] = str(size) + os.environ["WANDB_DISABLED"] = "true" port = str(get_free_port()) if port is None else str(port) diff --git a/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml new file mode 100644 index 0000000000..1d8fac655f --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/Llama-3_1-8B.yaml @@ -0,0 +1,108 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + 
+puzzle_dir: ??? +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to v0.4_mini + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + +scoring: + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + num_solutions: 1 + minimal_diversity: 2 + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - 
stats.attention_num_params + + human_constraints: + target_memory: 780_000 # 78_000 + + mip_constraints: + use_greedy_search: false + is_multi_layer_puzzle: true + metric_overrides: + constrain_search_func: + max_seconds_per_solution: 60 + +realize_model: + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 2 + micro_batch_size: 1 + dataset_path: ${dataset_path}/valid + seed: 42 + shuffle_seed: 444 + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml b/tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml new file mode 100644 index 0000000000..01886607e4 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/attn_pruning.yaml @@ -0,0 +1,16 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory # IndependentKvHeadContributionHook implementation that consumes less memory + target_layer: "self_attn.o_proj" + layer_input_descriptors_path: + +# n_heads_in_group: 4 +# num_attention_heads: 32 # num query heads +# num_kv_heads: 32 / 4 = 8 # num_query_heads // n_heads_in_group +n_heads_in_group_list: [8, 16, 32] # num_kv_heads = [4, 2, 1] +gqa_init_mode: "PruneKVHeads" diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml 
b/tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml new file mode 100644 index 0000000000..f0c852eec9 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/ffn_pruning.yaml @@ -0,0 +1,12 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/ffn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: iterative + target_layer: "mlp.down_proj" + layer_input_descriptors_path: + +intermediate_size_list: [256] # teacher_intermediate_size is 14336 +mlp_init_mode: "PruneByActivationsLog" diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml b/tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml new file mode 100644 index 0000000000..407c835d8c --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/hidden_dim_pruning.yaml @@ -0,0 +1,15 @@ +defaults: + - pruning_defaults + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +activation_hooks_kwargs: + method: layer_norm_contribution + target_layer: "layernorm" + +# Hidden dimension pruning specific settings +hidden_size_list: [3072, 2048] # Target hidden sizes to prune to +hidden_size_init_mode: "PruneByChannelRanking" +mlp_init_mode: "Truncate" # TODO, make it work with CopyAsIs/FromTeacher +gqa_init_mode: "AverageKV" # TODO, make it work with CopyAsIs/FromTeacher +linear_init_mode: "FromTeacher" diff --git a/tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml new file mode 100644 index 0000000000..0a5eafcfff --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/pruning/pruning_defaults.yaml @@ -0,0 +1,32 @@ +defaults: + - /validate_model_defaults + 
+model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? + +# Data: +eval_samples: 100 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_outpt_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml new file mode 100644 index 0000000000..046ff51f65 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/validate_model_defaults.yaml @@ -0,0 +1,15 @@ +block_size: 8192 +bos_rate: 0.5 +data_column: conversation +val_dataset_name: train +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:utils.data.dataloaders.load_from_disk_fn} diff --git a/tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml b/tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/configs/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git 
a/tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json b/tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json new file mode 100644 index 0000000000..02ee80b619 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/tokenizer/special_tokens_map.json @@ -0,0 +1,16 @@ +{ + "bos_token": { + "content": "<|begin_of_text|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "<|eot_id|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json new file mode 100644 index 0000000000..83592e2494 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer.json @@ -0,0 +1,212 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [], + "normalizer": null, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "Split", + "pattern": { + "Regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + }, + "behavior": "Isolated", + "invert": false + }, + { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true, + "use_regex": false + } + ] + }, + "post_processor": { + "type": "Sequence", + "processors": [ + { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": false, + "use_regex": true + }, + { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "<|begin_of_text|>", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "<|begin_of_text|>", + 
"type_id": 1 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + } + ], + "special_tokens": { + "<|begin_of_text|>": { + "id": "<|begin_of_text|>", + "ids": [ + 100 + ], + "tokens": [ + "<|begin_of_text|>" + ] + } + } + } + ] + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true, + "use_regex": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "byte_fallback": false, + "ignore_merges": true, + "vocab": { + "!": 0, + "\"": 1, + "#": 2, + "$": 3, + "%": 4, + "&": 5, + "'": 6, + "(": 7, + ")": 8, + "*": 9, + "+": 10, + ",": 11, + "-": 12, + ".": 13, + "/": 14, + "0": 15, + "1": 16, + "2": 17, + "3": 18, + "4": 19, + "5": 20, + "6": 21, + "7": 22, + "8": 23, + "9": 24, + ":": 25, + ";": 26, + "<": 27, + "=": 28, + ">": 29, + "?": 30, + "@": 31, + "A": 32, + "B": 33, + "C": 34, + "D": 35, + "E": 36, + "F": 37, + "G": 38, + "H": 39, + "I": 40, + "J": 41, + "K": 42, + "L": 43, + "M": 44, + "N": 45, + "O": 46, + "P": 47, + "Q": 48, + "R": 49, + "S": 50, + "T": 51, + "U": 52, + "V": 53, + "W": 54, + "X": 55, + "Y": 56, + "Z": 57, + "[": 58, + "\\": 59, + "]": 60, + "^": 61, + "_": 62, + "`": 63, + "a": 64, + "b": 65, + "c": 66, + "d": 67, + "e": 68, + "f": 69, + "g": 70, + "h": 71, + "i": 72, + "j": 73, + "k": 74, + "l": 75, + "m": 76, + "n": 77, + "o": 78, + "p": 79, + "q": 80, + "r": 81, + "s": 82, + "t": 83, + "u": 84, + "v": 85, + "w": 86, + "x": 87, + "y": 88, + "z": 89, + "{": 90, + "|": 91, + "}": 92, + "~": 93, + "¡": 94, + "¢": 95, + "£": 96, + "¤": 97, + "¥": 98, + "¦": 99, + "<|begin_of_text|>": 100, + "<|eot_id|>": 101 + }, + "merges": [] + } +} diff --git a/tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..754d9e8db5 --- /dev/null +++ 
b/tests/experimental/torch/_compress/resources/tokenizer/tokenizer_config.json @@ -0,0 +1,13 @@ +{ + "bos_token": "<|begin_of_text|>", + "chat_template": "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n", + "clean_up_tokenization_spaces": true, + "eos_token": "<|eot_id|>", + 
"extra_special_tokens": {}, + "model_input_names": [ + "input_ids", + "attention_mask" + ], + "model_max_length": 131072, + "tokenizer_class": "PreTrainedTokenizer" +} diff --git a/tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py b/tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py new file mode 100644 index 0000000000..aedcae4ab2 --- /dev/null +++ b/tests/experimental/torch/_compress/resources/tokenizer/truncate_tokenizer.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script was used to truncate the tokenizer.json file from Llama 3.1 8B model +to keep only the top 100 most common tokens. 
"""Truncate a Llama 3.1 8B tokenizer.json to keep only the top 100 most common tokens.

This script was used to produce the small test tokenizer checked into the
test resources directory.
"""

import json

# Path to your original and new tokenizer.json
in_path = "./tokenizer.json"
out_path = "./tokenizer_truncated.json"

# How many top tokens to keep
NUM_TO_KEEP = 100


def truncate_vocab(tokenizer_data: dict, num_to_keep: int = NUM_TO_KEEP) -> dict:
    """Truncate the vocab of a loaded tokenizer.json payload.

    Tokens are ranked by their original index (lowest index = assumed most
    common/important) and the kept tokens are re-indexed ``0..num_to_keep-1``.
    ``merges`` and ``added_tokens`` are cleared because they may reference
    dropped tokens.

    Args:
        tokenizer_data: parsed tokenizer.json content; mutated in place.
        num_to_keep: how many of the lowest-index tokens to keep.

    Returns:
        The same (mutated) ``tokenizer_data`` dict, for convenience.
    """
    orig_vocab = tokenizer_data["model"]["vocab"]

    # Sort tokens by their original index (frequency proxy) and keep the top N.
    tokens_to_keep = sorted(orig_vocab, key=orig_vocab.get)[:num_to_keep]

    # Re-index the selected tokens: 0..N-1
    small_vocab = {tok: i for i, tok in enumerate(tokens_to_keep)}
    tokenizer_data["model"]["vocab"] = small_vocab

    # Keep the recorded vocab size consistent with the truncated vocab.
    if "vocab_size" in tokenizer_data["model"]:
        tokenizer_data["model"]["vocab_size"] = len(small_vocab)

    # Merges may reference removed tokens (mostly for BPE/WordPiece).
    if "merges" in tokenizer_data["model"]:
        tokenizer_data["model"]["merges"] = []

    # added_tokens are not needed for the truncated test tokenizer.
    if "added_tokens" in tokenizer_data:
        tokenizer_data["added_tokens"] = []

    return tokenizer_data


def main() -> None:
    """Read ``in_path``, truncate its vocab, and write the result to ``out_path``."""
    with open(in_path, encoding="utf-8") as f:
        tokenizer_data = json.load(f)

    truncate_vocab(tokenizer_data)

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(tokenizer_data, f, indent=2, ensure_ascii=False)

    print(f"Truncated tokenizer saved to: {out_path}")


if __name__ == "__main__":
    main()
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os
import shutil
from functools import partial
from pathlib import Path

import pytest
import torch
from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
from datasets import Dataset, DatasetDict
from puzzle_tools.hydra_utils import register_hydra_resolvers
from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, PreTrainedTokenizerBase

from modelopt.torch._compress import compress
from modelopt.torch._compress.runtime import NativeDdpRuntime


@pytest.fixture
def project_root_path(request: pytest.FixtureRequest) -> Path:
    """Absolute path to the repository root (pytest's configured rootpath)."""
    return Path(request.config.rootpath)


# The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programming NAS search)
# using a one-click command.
#
# Note: Bypass is disabled now in the test.

# How to run this test (currently only supported internally at Nvidia).
#
# Have both modelopt and puzzle source code in the same directory:
# /workspace/modelopt
# /workspace/puzzletron
#
# submit_job --partition interactive --time 0 \
# --image gitlab-master.nvidia.com/deci/puzzletron:trtllm_main \
# --workdir <MODELOPT source directory> --interactive --gpu 1
#
# pip install mip
# pip install lru-dict
#
# export PYTHONPATH=$PYTHONPATH:/workspace/puzzletron/v1
#
# pytest -s -v ./tests/experimental/torch/_compress/test_compress.py::test_compress -o addopts=""


def test_compress(project_root_path: Path, tmp_path: Path) -> None:
    """E2e test entry point: spawn one process per available GPU and run the
    per-rank compression job under the NCCL backend."""
    spawn_multiprocess_job(
        size=torch.cuda.device_count(),
        job=partial(_test_compress_multiprocess_job, project_root_path, tmp_path),
        backend="nccl",
    )


def _test_compress_multiprocess_job(project_root_path: Path, tmp_path: Path, rank: int, size: int):
    """Per-rank compression job.

    Rank 0 prepares the puzzle dir, the dummy dataset, and a small Llama
    teacher checkpoint (converted to DeciLM format); then all ranks run
    ``compress.compress``; finally rank 0 asserts that each pipeline step
    produced its expected artifacts.

    Args:
        project_root_path: repository root (used to locate test resources).
        tmp_path: per-test temporary directory, used as the puzzle dir.
        rank: this process's global rank as assigned by spawn_multiprocess_job.
        size: total number of spawned processes.
    """
    register_hydra_resolvers()

    puzzle_dir = tmp_path
    dataset_path = puzzle_dir / "dummy_dataset"
    hydra_config_dir = project_root_path / "tests/experimental/torch/_compress/resources/configs"

    # timedelta(10) is a 10-day timeout (days is timedelta's first positional arg).
    _runtime = NativeDdpRuntime(
        dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10)
    )

    with _runtime as runtime:
        #
        # Test setup (rank 0 only; other ranks wait at the barrier below)
        #
        if rank == 0:
            # Setup puzzle_dir and dataset
            _setup_puzzle_dir(puzzle_dir)
            _save_dummy_dataset(dataset_path)

            #
            # Step 1: Create and save a teacher model to compress
            # This mimics the normal pipeline where we start with a Llama model
            #
            tokenizer_path = (
                project_root_path / "tests/experimental/torch/_compress/resources/tokenizer"
            )

            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

            # Create a small Llama model (not DeciLM) to match the normal conversion pipeline
            hf_ckpt_teacher_dir = "ckpts/teacher"
            llama_checkpoint_path = puzzle_dir / hf_ckpt_teacher_dir
            _create_and_save_small_llama_model(
                llama_checkpoint_path, vocab_size=tokenizer.vocab_size, tokenizer=tokenizer
            )

            # Use the full conversion pipeline (matches normal usage)
            convert_llama3_to_decilm(
                input_dir=llama_checkpoint_path,
                output_dir=llama_checkpoint_path,
            )
        runtime.wait_for_everyone()

        # Compress the model using a one-click approach
        compress.compress(
            str(hydra_config_dir), "Llama-3_1-8B", str(puzzle_dir), str(dataset_path), runtime
        )

        #
        # Check assertions (step numbers mirror the steps in
        # modelopt.torch._compress.compress, which has no step 3)
        #
        if rank == 0:
            # assertions for the score_pruning_activations step 1
            # NOTE(review): re-reads RANK from the environment, shadowing the
            # `rank` parameter; inside this branch both should be 0 — confirm.
            rank = int(os.environ["RANK"])
            rank_filepath = (
                f"pruning/pruning_scores/ffn_iterative/100samples_diverse_mini/rank_{rank}.pth"
            )
            assert (puzzle_dir / rank_filepath).is_file()

            # assertions for the pruning_ckpts step 2
            assert (puzzle_dir / "ckpts/ffn_256_attn_no_op").exists()

            # assertions for the build_library_and_stats step 4

            assert (puzzle_dir / "replacement_library.json").is_file()
            assert (puzzle_dir / "subblock_stats.json").is_file()

            # assertions for the scoring step 5
            solution_0_filepath = (
                puzzle_dir / "single_sequence_replacement_solutions--validation/solution_0.json"
            )

            assert solution_0_filepath.exists()

            # assertions for the mip_and_realize_models step 6
            solution_0_ckpt_config_path = (
                puzzle_dir
                / "mip/puzzle_solutions/target_memory_780000MiB/solutions--checkpoints/solution_0/config.json"
            )

            assert solution_0_ckpt_config_path.exists()
            assert (
                puzzle_dir / "mip/puzzle_solutions/target_memory_780000MiB/solutions.json"
            ).exists()

        runtime.wait_for_everyone()

        # NOTE(review): the message says test_compress_model() but the test is
        # named test_compress — consider aligning the two.
        print("PYTEST SUMMARY: test_compress_model() test has finished successfully")


def _create_and_save_small_llama_model(
    output_path: str, vocab_size: int, tokenizer: PreTrainedTokenizerBase
):
    """
    Create and save a small Llama model for testing the conversion pipeline.
    This mimics having a real Llama checkpoint that needs to be converted.

    Args:
        output_path: directory to write the checkpoint to (str or Path);
            created if missing.
        vocab_size: vocabulary size for the model config (taken from the
            test tokenizer).
        tokenizer: tokenizer to save alongside the model weights.
    """
    os.makedirs(output_path, exist_ok=True)

    # Create a minimal Llama config (small for testing)
    # Note: intermediate_size must be divisible by 256 per DeciLM config requirements
    # Note: hidden_size must give head_dim >= 8 for Flash Attention 2 compatibility
    llama_config = LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=256,  # 32 heads times 8 head_dim = 256 (matches bypass config expectations)
        intermediate_size=512,  # Must be divisible by 256
        num_hidden_layers=2,
        num_attention_heads=32,  # Matches original test
        num_key_value_heads=8,  # GQA: 32÷4=8 (matches original n_heads_in_group=4)
        max_position_embeddings=512,
        rms_norm_eps=1e-5,
        rope_theta=10000.0,
        attention_bias=False,
        hidden_act="silu",
        tie_word_embeddings=False,
    )

    # Create and save the Llama model (bfloat16 to match the runtime dtype)
    model = LlamaForCausalLM(llama_config)
    model.to(dtype=torch.bfloat16).save_pretrained(output_path)

    # Save tokenizer
    tokenizer.save_pretrained(output_path)

    # Save config
    llama_config.save_pretrained(output_path)


def _setup_puzzle_dir(puzzle_dir: str):
    """Create an empty puzzle dir, removing any previous contents.

    Accepts a str or Path.
    """
    if Path(puzzle_dir).exists():
        shutil.rmtree(puzzle_dir)
    Path(puzzle_dir).mkdir(parents=True, exist_ok=True)


def _save_dummy_dataset(dataset_path: str):
    """Write a dummy chat dataset (train/valid splits) to dataset_path.

    Accepts a str or Path; the dataset is one user/assistant conversation
    repeated 2500 times.
    """
    # dummy sample
    sample = [
        {"role": "user", "content": "please cite Lorem Ipsum?"},
        {
            "role": "assistant",
            "content": (
                "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed in blandit ante. "
                "Sed tempus erat urna, ac elementum nisl facilisis quis. Aliquam consectetur mollis massa, "
                "in elementum sem venenatis posuere. Fusce lorem arcu, egestas vel massa sollicitudin, "
                "dictum mollis purus. Proin in ullamcorper elit. Nam tellus nisi, volutpat a mattis vel, "
                "pretium in purus. Nunc at lectus facilisis risus scelerisque rhoncus eu nec ex. "
                "Maecenas semper, tellus non placerat vulputate, urna felis facilisis diam, "
                "sit amet vestibulum erat sapien nec libero. Praesent non massa velit. Donec faucibus mi eros. "
                "Nam turpis nulla, congue sit amet mi at, porttitor scelerisque elit. Nunc id sodales lorem, "
                "nec tincidunt leo. Quisque a neque nec ligula porttitor auctor. "
                "Nunc accumsan nunc ac tellus congue vehicula. Praesent tellus eros, luctus non gravida dapibus, "
                "faucibus eu ex. Quisque bibendum leo pharetra, tristique est vitae, hendrerit nunc. "
                "Duis nec congue dolor. Donec commodo ipsum non efficitur volutpat. "
                "Nulla risus nulla, efficitur et urna at, imperdiet sodales lorem. "
                "Suspendisse erat est, sollicitudin at nisl tincidunt, vehicula hendrerit lectus. "
                "Nam quis nisi ullamcorper, rhoncus massa vel, tempus purus. "
                "Duis pulvinar eros vel nulla pellentesque, at dapibus justo laoreet. "
                "Praesent tortor orci, vulputate fermentum dapibus nec, feugiat vitae tortor. "
                "Donec mollis convallis massa quis iaculis."
            ),
        },
    ]

    # Prepare train and val splits with sample repeated, 2500 samples are for
    # 128 samples with block-size 8192 and LLama3 tokenizer
    data = [{"conversation": sample}] * 2500

    # For train-val splits
    data_dict = DatasetDict({"train": Dataset.from_list(data), "valid": Dataset.from_list(data)})
    data_dict.save_to_disk(dataset_path)