diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py index 9d1ae6b0..ba277633 100644 --- a/src/pruna/config/smash_config.py +++ b/src/pruna/config/smash_config.py @@ -286,10 +286,14 @@ def load_from_json(self, path: str | Path) -> None: saved_values = {k: v for k, v in config_dict.items() if k in supported_hparam_names} # Seed with the defaults, then overlay the saved values - default_values = dict(SMASH_SPACE.get_default_configuration()) - default_values.update(saved_values) - - self._configuration = Configuration(SMASH_SPACE, values=default_values) + self._configuration = SMASH_SPACE.get_default_configuration() + # activate all algorithms already in the space to make children hyperparameters appear + saved_algorithm_keys = set(self._configuration.keys()) & set(saved_values.keys()) + for key in saved_algorithm_keys: + self._configuration[key] = True + # register all saved hyperparameters + for key, value in saved_values.items(): + self._configuration[key] = value tokenizer_path = Path(path) / TOKENIZER_SAVE_PATH if tokenizer_path.exists():