Skip to content
Merged
123 changes: 122 additions & 1 deletion astrbot/core/utils/pip_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import importlib.util
import io
import logging
import ntpath
import os
import re
import shlex
import sys
import threading
from collections import deque
from collections.abc import Mapping
from dataclasses import dataclass
from urllib.parse import urlparse

Expand All @@ -30,6 +32,9 @@

_DISTLIB_FINDER_PATCH_ATTEMPTED = False
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
_PIP_IN_PROCESS_ENV_LOCK = threading.RLock()
_WINDOWS_UNC_PATH_PREFIXES = ("\\\\?\\UNC\\", "\\??\\UNC\\")
_WINDOWS_EXTENDED_PATH_PREFIXES = ("\\\\?\\", "\\??\\")
_PIP_FAILURE_PATTERNS = {
"error_prefix": re.compile(r"^\s*error:", re.IGNORECASE),
"user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE),
Expand Down Expand Up @@ -235,6 +240,120 @@ def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]:
return result_code, stream.lines


@contextlib.contextmanager
def _temporary_environ(updates: Mapping[str, str]):
if not updates:
yield
return

missing = object()
previous_values = {key: os.environ.get(key, missing) for key in updates}

try:
os.environ.update(updates)
yield
finally:
for key, previous_value in previous_values.items():
if previous_value is missing:
os.environ.pop(key, None)
else:
assert isinstance(previous_value, str)
os.environ[key] = previous_value


def _run_pip_main_with_temporary_environ(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
pip_main,
args: list[str],
) -> tuple[int, list[str]]:
# os.environ is process-wide; serialize reading current INCLUDE/LIB values
# together with the temporary mutation window around the in-process pip
# invocation.
with _PIP_IN_PROCESS_ENV_LOCK:
env_updates = _build_packaged_windows_runtime_build_env(base_env=os.environ)
if not env_updates:
return _run_pip_main_streaming(pip_main, args)

with _temporary_environ(env_updates):
return _run_pip_main_streaming(pip_main, args)


def _normalize_windows_native_build_path(path: str) -> str:
"""Normalize a Windows path returned by native APIs or sys.executable.

Extended UNC prefixes are converted back to the standard ``\\server`` form,
other extended prefixes are stripped, and the remaining path is normalized.
"""
normalized = path.replace("/", "\\")

# Extended UNC: \\?\UNC\server\share\... -> \\server\share\...
for prefix in _WINDOWS_UNC_PATH_PREFIXES:
if normalized.startswith(prefix):
return ntpath.normpath(f"\\\\{normalized[len(prefix) :]}")

# Other extended prefixes are stripped before normalizing the path.
for prefix in _WINDOWS_EXTENDED_PATH_PREFIXES:
if normalized.startswith(prefix):
normalized = normalized[len(prefix) :]
break

return ntpath.normpath(normalized)


def _get_case_insensitive_env_value(
env: Mapping[str, str],
upper_to_key: Mapping[str, str],
name: str,
) -> str | None:
direct = env.get(name)
if direct is not None:
return direct

existing_key = upper_to_key.get(name.upper())
if existing_key is not None:
return env.get(existing_key)

return None


def _build_packaged_windows_runtime_build_env(
*,
base_env: Mapping[str, str] | None = None,
) -> dict[str, str]:
if sys.platform != "win32" or not is_packaged_desktop_runtime():
return {}

base_env = os.environ if base_env is None else base_env

runtime_executable = _normalize_windows_native_build_path(sys.executable)
runtime_dir = ntpath.dirname(runtime_executable)
if not runtime_dir:
return {}

include_dir = _normalize_windows_native_build_path(
ntpath.join(runtime_dir, "include")
)
libs_dir = _normalize_windows_native_build_path(ntpath.join(runtime_dir, "libs"))
include_exists = os.path.isdir(include_dir)
libs_exists = os.path.isdir(libs_dir)

if not (include_exists or libs_exists):
return {}

upper_to_key = {key.upper(): key for key in base_env}
env_updates: dict[str, str] = {}

if include_exists:
existing = _get_case_insensitive_env_value(base_env, upper_to_key, "INCLUDE")
env_updates["INCLUDE"] = (
f"{include_dir};{existing}" if existing else include_dir
)
if libs_exists:
existing = _get_case_insensitive_env_value(base_env, upper_to_key, "LIB")
env_updates["LIB"] = f"{libs_dir};{existing}" if existing else libs_dir

return env_updates


def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool:
names = pattern_names or tuple(_PIP_FAILURE_PATTERNS)
return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names)
Expand Down Expand Up @@ -931,7 +1050,9 @@ async def _run_pip_in_process(self, args: list[str]) -> int:
original_handlers = list(logging.getLogger().handlers)
try:
result_code, output_lines = await asyncio.to_thread(
_run_pip_main_streaming, pip_main, args
_run_pip_main_with_temporary_environ,
pip_main,
args,
)
finally:
_cleanup_added_root_handlers(original_handlers)
Expand Down
Loading
Loading