Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion astrbot/core/provider/sources/whisper_api_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav,
tencent_silk_to_wav,
Expand Down Expand Up @@ -76,7 +77,22 @@ async def get_text(self, audio_url: str) -> str:
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")

if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
lower_audio_url = audio_url.lower()

if lower_audio_url.endswith(".opus"):
temp_dir = get_astrbot_temp_path()
output_path = os.path.join(
temp_dir,
f"whisper_api_{uuid.uuid4().hex[:8]}.wav",
)
logger.info("Converting opus file to wav using convert_audio_to_wav...")
await convert_audio_to_wav(audio_url, output_path)
audio_url = output_path
elif (
lower_audio_url.endswith(".amr")
or lower_audio_url.endswith(".silk")
or is_tencent
):
file_format = await self._get_audio_format(audio_url)

# 判断是否需要转换
Expand Down
72 changes: 72 additions & 0 deletions tests/test_whisper_api_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from astrbot.core.provider.sources.whisper_api_source import ProviderOpenAIWhisperAPI


def _make_provider() -> ProviderOpenAIWhisperAPI:
    """Build a Whisper API provider whose ``client`` is replaced by a stub.

    The stub exposes ``audio.transcriptions.create`` as an ``AsyncMock`` that
    resolves to an object with ``text == "transcribed text"``, plus an
    awaitable ``close``. The real client attribute set by the constructor is
    overwritten.
    """
    config = {
        "id": "test-whisper-api",
        "type": "openai_whisper_api",
        "model": "whisper-1",
        "api_key": "test-key",
    }
    whisper = ProviderOpenAIWhisperAPI(
        provider_config=config,
        provider_settings={},
    )

    # Assemble the fake client from the innermost attribute outwards.
    transcription_result = SimpleNamespace(text="transcribed text")
    transcriptions_stub = SimpleNamespace(
        create=AsyncMock(return_value=transcription_result)
    )
    audio_stub = SimpleNamespace(transcriptions=transcriptions_stub)
    whisper.client = SimpleNamespace(audio=audio_stub, close=AsyncMock())
    return whisper


@pytest.mark.asyncio
async def test_get_text_converts_opus_files_to_wav_before_transcription(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """Check that ``get_text`` converts a local ``.opus`` file to wav first.

    The module-level ``convert_audio_to_wav`` and ``get_astrbot_temp_path``
    are patched so the conversion is recorded and its output lands in
    ``tmp_path``. The test then expects:

    * the stubbed transcription text to be returned,
    * the converter to have been called with the original opus path,
    * the converted ``.wav`` file to no longer exist once ``get_text`` returns,
    * the transcription client to have been awaited once with a ``file``
      argument of ``("audio.wav", <handle whose name ends in .wav>, ...)``.
    """
    whisper = _make_provider()
    source_file = tmp_path / "voice.opus"
    source_file.write_bytes(b"fake opus data")

    # Each entry is (input path, output path) as seen by the fake converter.
    recorded_calls: list[tuple[str, str]] = []

    async def fake_convert_audio_to_wav(
        audio_path: str, output_path: str | None = None
    ):
        assert output_path is not None
        recorded_calls.append((audio_path, output_path))
        # Create the output so the code under test has a real file to open.
        Path(output_path).write_bytes(b"fake wav data")
        return output_path

    patches = (
        (
            "astrbot.core.provider.sources.whisper_api_source.get_astrbot_temp_path",
            lambda: str(tmp_path),
        ),
        (
            "astrbot.core.provider.sources.whisper_api_source.convert_audio_to_wav",
            fake_convert_audio_to_wav,
        ),
    )
    for target, replacement in patches:
        monkeypatch.setattr(target, replacement)

    try:
        transcript = await whisper.get_text(str(source_file))

        assert transcript == "transcribed text"

        assert recorded_calls
        first_input, first_output = recorded_calls[0]
        assert first_input == str(source_file)

        wav_file = Path(first_output)
        assert wav_file.suffix == ".wav"
        # The temporary wav is expected to be gone after transcription.
        assert not wav_file.exists()

        transcribe = whisper.client.audio.transcriptions.create
        transcribe.assert_awaited_once()
        uploaded = transcribe.await_args.kwargs["file"]
        assert uploaded[0] == "audio.wav"
        assert uploaded[1].name.endswith(".wav")
        # The mock never consumes the handle, so release it here.
        uploaded[1].close()
    finally:
        await whisper.terminate()
Loading