Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 104 additions & 4 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.media_utils import (
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
IMAGE_COMPRESS_DEFAULT_QUALITY,
compress_image,
)
from astrbot.core.utils.quoted_message.settings import (
SETTINGS as DEFAULT_QUOTED_MESSAGE_SETTINGS,
)
Expand Down Expand Up @@ -473,16 +478,23 @@ async def _request_img_caption(


async def _ensure_img_caption(
event: AstrMessageEvent,
req: ProviderRequest,
cfg: dict,
plugin_context: Context,
image_caption_provider: str,
) -> None:
try:
compressed_urls = []
for url in req.image_urls:
compressed_url = await _compress_image_for_provider(url, cfg)
Comment on lines +488 to +490
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Image compression in _ensure_img_caption ignores user provider_settings due to wrong argument type.

_ensure_img_caption passes cfg into _compress_image_for_provider, but that function expects dict[str, object] | None, and _get_image_compress_args only reads settings when given a dict. Because cfg is a MainAgentBuildConfig, this path always falls back to defaults and ignores provider_settings.image_compress_enabled and image_compress_options. To match other call sites (_process_quote_message, build_main_agent) and honor user configuration, this should pass cfg.provider_settings instead.

compressed_urls.append(compressed_url)
Comment thread
Soulter marked this conversation as resolved.
if _is_generated_compressed_image_path(url, compressed_url):
event.track_temporary_local_file(compressed_url)
caption = await _request_img_caption(
image_caption_provider,
cfg,
req.image_urls,
compressed_urls,
plugin_context,
)
if caption:
Expand All @@ -492,6 +504,9 @@ async def _ensure_img_caption(
req.image_urls = []
except Exception as exc: # noqa: BLE001
logger.error("处理图片描述失败: %s", exc)
req.extra_user_content_parts.append(TextPart(text="[Image Captioning Failed]"))
finally:
req.image_urls = []
Comment thread
Soulter marked this conversation as resolved.


def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None:
Expand All @@ -511,12 +526,64 @@ def _get_quoted_message_parser_settings(
return DEFAULT_QUOTED_MESSAGE_SETTINGS.with_overrides(overrides)


def _get_image_compress_args(
provider_settings: dict[str, object] | None,
) -> tuple[bool, int, int]:
if not isinstance(provider_settings, dict):
return True, IMAGE_COMPRESS_DEFAULT_MAX_SIZE, IMAGE_COMPRESS_DEFAULT_QUALITY

enabled = provider_settings.get("image_compress_enabled", True)
if not isinstance(enabled, bool):
enabled = True

raw_options = provider_settings.get("image_compress_options", {})
options = raw_options if isinstance(raw_options, dict) else {}

max_size = options.get("max_size", IMAGE_COMPRESS_DEFAULT_MAX_SIZE)
if not isinstance(max_size, int):
max_size = IMAGE_COMPRESS_DEFAULT_MAX_SIZE
max_size = max(max_size, 1)

quality = options.get("quality", IMAGE_COMPRESS_DEFAULT_QUALITY)
if not isinstance(quality, int):
quality = IMAGE_COMPRESS_DEFAULT_QUALITY
quality = min(max(quality, 1), 100)

return enabled, max_size, quality


async def _compress_image_for_provider(
url_or_path: str,
provider_settings: dict[str, object] | None,
) -> str:
try:
enabled, max_size, quality = _get_image_compress_args(provider_settings)
if not enabled:
return url_or_path
return await compress_image(url_or_path, max_size=max_size, quality=quality)
except Exception as exc: # noqa: BLE001
logger.error("Image compression failed: %s", exc)
return url_or_path


def _is_generated_compressed_image_path(
original_path: str,
compressed_path: str | None,
) -> bool:
if not compressed_path or compressed_path == original_path:
return False
if compressed_path.startswith("http") or compressed_path.startswith("data:image"):
return False
return os.path.exists(compressed_path)


async def _process_quote_message(
event: AstrMessageEvent,
req: ProviderRequest,
img_cap_prov_id: str,
plugin_context: Context,
quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
config: MainAgentBuildConfig | None = None,
) -> None:
quote = None
for comp in event.message_obj.message:
Expand Down Expand Up @@ -549,15 +616,24 @@ async def _process_quote_message(
if image_seg:
try:
prov = None
path = None
compress_path = None
if img_cap_prov_id:
prov = plugin_context.get_provider_by_id(img_cap_prov_id)
if prov is None:
prov = plugin_context.get_using_provider(event.unified_msg_origin)

if prov and isinstance(prov, Provider):
path = await image_seg.convert_to_file_path()
compress_path = await _compress_image_for_provider(
path,
config.provider_settings if config else None,
)
if path and _is_generated_compressed_image_path(path, compress_path):
event.track_temporary_local_file(compress_path)
llm_resp = await prov.text_chat(
prompt="Please describe the image content.",
image_urls=[await image_seg.convert_to_file_path()],
image_urls=[compress_path],
)
if llm_resp.completion_text:
content_parts.append(
Expand All @@ -567,6 +643,16 @@ async def _process_quote_message(
logger.warning("No provider found for image captioning in quote.")
except BaseException as exc:
logger.error("处理引用图片失败: %s", exc)
finally:
if (
compress_path
and compress_path != path
and os.path.exists(compress_path)
):
try:
os.remove(compress_path)
except Exception as exc: # noqa: BLE001
logger.warning("Fail to remove temporary compressed image: %s", exc)

quoted_content = "\n".join(content_parts)
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
Expand Down Expand Up @@ -635,6 +721,7 @@ async def _decorate_llm_request(
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
if img_cap_prov_id and req.image_urls:
await _ensure_img_caption(
event,
req,
cfg,
plugin_context,
Expand All @@ -649,6 +736,7 @@ async def _decorate_llm_request(
img_cap_prov_id,
plugin_context,
quoted_message_settings,
config,
)

tz = config.timezone
Expand Down Expand Up @@ -1025,7 +1113,13 @@ async def build_main_agent(
# media files attachments
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
path = await comp.convert_to_file_path()
image_path = await _compress_image_for_provider(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
path,
config.provider_settings,
)
Comment thread
Soulter marked this conversation as resolved.
if _is_generated_compressed_image_path(path, image_path):
event.track_temporary_local_file(image_path)
req.image_urls.append(image_path)
req.extra_user_content_parts.append(
TextPart(text=f"[Image Attachment: path {image_path}]")
Expand All @@ -1052,7 +1146,13 @@ async def build_main_agent(
for reply_comp in comp.chain:
if isinstance(reply_comp, Image):
has_embedded_image = True
image_path = await reply_comp.convert_to_file_path()
path = await reply_comp.convert_to_file_path()
image_path = await _compress_image_for_provider(
path,
config.provider_settings,
)
if _is_generated_compressed_image_path(path, image_path):
event.track_temporary_local_file(image_path)
req.image_urls.append(image_path)
_append_quoted_image_attachment(req, image_path)
elif isinstance(reply_comp, File):
Expand Down
28 changes: 28 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@
"shipyard_neo_profile": "python-default",
"shipyard_neo_ttl": 3600,
},
"image_compress_enabled": True,
"image_compress_options": {
"max_size": 1280,
"quality": 95,
},
},
# SubAgent orchestrator mode:
# - main_enable = False: disabled; main LLM mounts tools normally (persona selection).
Expand Down Expand Up @@ -3452,6 +3457,29 @@ class ChatProviderTemplate(TypedDict):
"type": "string",
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
},
"provider_settings.image_compress_enabled": {
"description": "启用图片压缩",
"type": "bool",
"hint": "启用后,发送给多模态模型前会先压缩本地大图片。",
},
"provider_settings.image_compress_options.max_size": {
"description": "最大边长",
"type": "int",
"hint": "压缩后图片的最长边,单位为像素。超过该尺寸时会按比例缩放。",
"condition": {
"provider_settings.image_compress_enabled": True,
},
"slider": {"min": 256, "max": 4096, "step": 64},
},
"provider_settings.image_compress_options.quality": {
"description": "压缩质量",
"type": "int",
"hint": "JPEG 输出质量,范围为 1-100。值越高,画质越好,文件也越大。",
"condition": {
"provider_settings.image_compress_enabled": True,
},
"slider": {"min": 1, "max": 100, "step": 1},
},
"provider_tts_settings.dual_output": {
"description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool",
Expand Down
1 change: 1 addition & 0 deletions astrbot/core/pipeline/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,5 @@ async def execute(self, event: AstrMessageEvent) -> None:

logger.debug("pipeline 执行完毕。")
finally:
event.cleanup_temporary_local_files()
active_event_registry.unregister(event)
21 changes: 21 additions & 0 deletions astrbot/core/platform/astr_message_event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import asyncio
import hashlib
import os
import re
import uuid
from collections.abc import AsyncGenerator
Expand Down Expand Up @@ -88,6 +89,8 @@ def __init__(
"""在此次事件中是否有过至少一次发送消息的操作"""
self.call_llm = False
"""是否在此消息事件中禁止默认的 LLM 请求"""
self._temporary_local_files: list[str] = []
"""Temporary local files created during this event and safe to delete when it finishes."""

self.plugins_name: list[str] | None = None
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
Expand Down Expand Up @@ -228,6 +231,24 @@ def clear_extra(self) -> None:
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
self._extras.clear()

def track_temporary_local_file(self, path: str) -> None:
if path and path not in self._temporary_local_files:
self._temporary_local_files.append(path)

def cleanup_temporary_local_files(self) -> None:
paths = list(self._temporary_local_files)
self._temporary_local_files.clear()
for path in paths:
try:
if os.path.exists(path):
os.remove(path)
except OSError as e:
logger.warning(
"Failed to remove temporary local file %s: %s",
path,
e,
)

def is_private_chat(self) -> bool:
"""是否是私聊。"""
return self.get_message_type() == MessageType.FRIEND_MESSAGE
Expand Down
Loading
Loading