diff --git a/install.py b/install.py index b02753d..105ad56 100644 --- a/install.py +++ b/install.py @@ -1,11 +1,14 @@ import launch import sys -from importlib_metadata import version python = sys.executable def install(): + if not launch.is_installed("importlib_metadata"): + launch.run_pip("install importlib_metadata", "importlib_metadata", live=True) + from importlib_metadata import version + if launch.is_installed("tensorrt"): if not version("tensorrt") == "9.0.1.post11.dev4": launch.run( diff --git a/scripts/trt.py b/scripts/trt.py index b9f16aa..1c16b74 100644 --- a/scripts/trt.py +++ b/scripts/trt.py @@ -299,7 +299,7 @@ def process_batch(self, p, *args, **kwargs): if self.torch_unet: return super().process_batch(p, *args, **kwargs) - if self.idx != sd_unet.current_unet.profile_idx: + if sd_unet.current_unet is not None and self.idx != sd_unet.current_unet.profile_idx: sd_unet.current_unet.profile_idx = self.idx sd_unet.current_unet.switch_engine() diff --git a/ui_trt.py b/ui_trt.py index 4f19847..6075084 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -268,7 +268,10 @@ def get_lora_checkpoints(): if os.path.exists(config_file): with open(config_file, "r") as f: config = json.load(f) - version = SDVersion.from_str(config["sd version"]) + try: + version = SDVersion.from_str(config["sd version"]) + except: + version = SDVersion.Unknown else: version = SDVersion.Unknown