diff --git a/CLAUDE.md b/CLAUDE.md index 8ee8193c..aadafd4c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -122,7 +122,7 @@ npm run web ```bash # ASR Services cd extras/asr-services -docker compose up parakeet # Offline ASR with Parakeet +docker compose up parakeet-asr # Offline ASR with Parakeet # Speaker Recognition (with tests) cd extras/speaker-recognition @@ -136,13 +136,6 @@ docker compose up --build ## Architecture Overview -### Core Structure -- **backends/advanced-backend/**: Primary FastAPI backend with real-time audio processing - - `src/main.py`: Central FastAPI application with WebSocket audio streaming - - `src/auth.py`: Email-based authentication with JWT tokens - - `src/memory/`: LLM-powered conversation memory system using mem0 - - `webui/`: React-based web dashboard for conversation and user management - ### Key Components - **Audio Pipeline**: Real-time Opus/PCM โ†’ Application-level processing โ†’ Deepgram/Mistral transcription โ†’ memory extraction - **Wyoming Protocol**: WebSocket communication uses Wyoming protocol (JSONL + binary) for structured audio sessions @@ -1214,12 +1207,6 @@ curl http://[gpu-machine-ip]:8085/health # Speaker recognition ### Troubleshooting Distributed Setup -**Common Issues:** -- **CORS errors**: Tailscale IPs are automatically supported, but verify CORS_ORIGINS if using custom IPs -- **Service discovery**: Use `tailscale ip` to find machine IPs -- **Port conflicts**: Ensure services use different ports on shared machines -- **Authentication**: Services must be accessible without authentication for inter-service communication - **Debugging Commands:** ```bash # Check Tailscale connectivity diff --git a/backends/advanced/src/advanced_omi_backend/audio_utils.py b/backends/advanced/src/advanced_omi_backend/audio_utils.py index 2821d126..1a3937c7 100644 --- a/backends/advanced/src/advanced_omi_backend/audio_utils.py +++ b/backends/advanced/src/advanced_omi_backend/audio_utils.py @@ -6,6 +6,10 @@ import logging import os 
import time +import wave +import io +import numpy as np +from pathlib import Path # Type import to avoid circular imports from typing import TYPE_CHECKING, Optional @@ -88,6 +92,69 @@ async def process_audio_chunk( client_state.update_audio_received(chunk) +async def load_audio_file_as_chunk(audio_path: Path) -> AudioChunk: + """Load existing audio file into Wyoming AudioChunk format for reprocessing. + + Args: + audio_path: Path to the audio file on disk + + Returns: + AudioChunk object ready for processing + + Raises: + FileNotFoundError: If audio file doesn't exist + ValueError: If audio file format is invalid + """ + try: + # Read the audio file + with open(audio_path, 'rb') as f: + file_content = f.read() + + # Process WAV file using existing pattern from system_controller.py + with wave.open(io.BytesIO(file_content), "rb") as wav_file: + sample_rate = wav_file.getframerate() + sample_width = wav_file.getsampwidth() + channels = wav_file.getnchannels() + audio_data = wav_file.readframes(wav_file.getnframes()) + + # Convert to mono if stereo (same logic as system_controller.py) + if channels == 2: + if sample_width == 2: + audio_array = np.frombuffer(audio_data, dtype=np.int16) + audio_array = audio_array.reshape(-1, 2) + audio_data = np.mean(audio_array, axis=1, dtype=np.int16).tobytes() + channels = 1 + else: + raise ValueError(f"Unsupported sample width for stereo: {sample_width}") + + # Validate format matches expected (16kHz, mono, 16-bit) + if sample_rate != 16000: + raise ValueError(f"Audio file has sample rate {sample_rate}Hz, expected 16kHz") + if channels != 1: + raise ValueError(f"Audio file has {channels} channels, expected mono") + if sample_width != 2: + raise ValueError(f"Audio file has {sample_width}-byte samples, expected 2 bytes") + + # Create AudioChunk with current timestamp + chunk = AudioChunk( + audio=audio_data, + rate=sample_rate, + width=sample_width, + channels=channels, + timestamp=int(time.time() * 1000) + ) + + logger.info(f"Loaded 
audio file {audio_path} as AudioChunk ({len(audio_data)} bytes)") + return chunk + + except FileNotFoundError: + logger.error(f"Audio file not found: {audio_path}") + raise + except Exception as e: + logger.error(f"Error loading audio file {audio_path}: {e}") + raise ValueError(f"Invalid audio file format: {e}") + + async def _process_audio_cropping_with_relative_timestamps( original_path: str, speech_segments: list[tuple[float, float]], diff --git a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py index e53eef88..3df2a281 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/conversation_controller.py @@ -5,12 +5,14 @@ import asyncio import hashlib import logging +import os import time from pathlib import Path from typing import Optional from advanced_omi_backend.audio_utils import ( _process_audio_cropping_with_relative_timestamps, + load_audio_file_as_chunk, ) from advanced_omi_backend.client_manager import ( ClientManager, @@ -18,7 +20,8 @@ get_user_clients_all, ) from advanced_omi_backend.database import AudioChunksRepository, ProcessingRunsRepository, chunks_col, processing_runs_col, conversations_col, ConversationsRepository -from advanced_omi_backend.users import User +from advanced_omi_backend.processors import get_processor_manager, TranscriptionItem, MemoryProcessingItem +from advanced_omi_backend.users import User, get_user_by_id from fastapi.responses import JSONResponse logger = logging.getLogger(__name__) @@ -585,9 +588,10 @@ async def reprocess_transcript(conversation_id: str, user: User): ) # Generate configuration hash for duplicate detection + transcription_provider = os.getenv("TRANSCRIPTION_PROVIDER", "deepgram") config_data = { "audio_path": str(full_audio_path), - "transcription_provider": "deepgram", # This would come from 
settings + "transcription_provider": transcription_provider, "trigger": "manual_reprocess" } config_hash = hashlib.sha256(str(config_data).encode()).hexdigest()[:16] @@ -613,18 +617,37 @@ async def reprocess_transcript(conversation_id: str, user: User): status_code=500, content={"error": "Failed to create transcript version"} ) - # TODO: Queue audio for reprocessing with ProcessorManager - # This is where we would integrate with the existing processor - # For now, we'll return the version ID for the caller to handle + # NEW: Load audio file and queue for transcription processing + try: + # Load audio file as AudioChunk + audio_chunk = await load_audio_file_as_chunk(full_audio_path) + + # Create TranscriptionItem for reprocessing + transcription_item = TranscriptionItem( + client_id=f"reprocess-{conversation_id}", + user_id=str(user.user_id), + audio_uuid=audio_uuid, + audio_chunk=audio_chunk + ) + + # Queue for transcription processing + processor_manager = get_processor_manager() + await processor_manager.queue_transcription(transcription_item) + + logger.info(f"Queued transcript reprocessing job {run_id} (version {version_id}) for conversation {conversation_id}") - logger.info(f"Created transcript reprocessing job {run_id} (version {version_id}) for conversation {conversation_id}") + except Exception as e: + logger.error(f"Error queuing transcript reprocessing: {e}") + return JSONResponse( + status_code=500, content={"error": f"Failed to queue reprocessing: {str(e)}"} + ) return JSONResponse(content={ "message": f"Transcript reprocessing started for conversation {conversation_id}", "run_id": run_id, "version_id": version_id, "config_hash": config_hash, - "status": "PENDING" + "status": "QUEUED" }) except Exception as e: @@ -673,9 +696,10 @@ async def reprocess_memory(conversation_id: str, transcript_version_id: str, use ) # Generate configuration hash for duplicate detection + memory_provider = os.getenv("MEMORY_PROVIDER", "friend_lite") config_data = { 
"transcript_version_id": transcript_version_id, - "memory_provider": "friend_lite", # This would come from settings + "memory_provider": memory_provider, "trigger": "manual_reprocess" } config_hash = hashlib.sha256(str(config_data).encode()).hexdigest()[:16] @@ -702,10 +726,34 @@ async def reprocess_memory(conversation_id: str, transcript_version_id: str, use status_code=500, content={"error": "Failed to create memory version"} ) - # TODO: Queue memory extraction for processing - # This is where we would integrate with the existing memory processor + # NEW: Queue memory processing + try: + # Get user email for memory processing + user_obj = await get_user_by_id(str(user.user_id)) + if not user_obj: + return JSONResponse( + status_code=500, content={"error": "User not found for memory processing"} + ) + + # Create MemoryProcessingItem for reprocessing + memory_item = MemoryProcessingItem( + client_id=f"reprocess-{conversation_id}", + user_id=str(user.user_id), + user_email=user_obj.email, + conversation_id=conversation_id + ) + + # Queue for memory processing + processor_manager = get_processor_manager() + await processor_manager.queue_memory(memory_item) - logger.info(f"Created memory reprocessing job {run_id} (version {version_id}) for conversation {conversation_id}") + logger.info(f"Queued memory reprocessing job {run_id} (version {version_id}) for conversation {conversation_id}") + + except Exception as e: + logger.error(f"Error queuing memory reprocessing: {e}") + return JSONResponse( + status_code=500, content={"error": f"Failed to queue memory reprocessing: {str(e)}"} + ) return JSONResponse(content={ "message": f"Memory reprocessing started for conversation {conversation_id}", @@ -713,7 +761,7 @@ async def reprocess_memory(conversation_id: str, transcript_version_id: str, use "version_id": version_id, "transcript_version_id": transcript_version_id, "config_hash": config_hash, - "status": "PENDING" + "status": "QUEUED" }) except Exception as e: diff --git 
a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py index d863985f..095c6801 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py @@ -523,9 +523,14 @@ async def list_processing_jobs(): async def process_files_with_content( job_id: str, file_data: list[tuple[str, bytes]], user: User, device_name: str ): - """Background task to process uploaded files using pre-read content.""" + """Background task to process uploaded files using pre-read content. + + Creates persistent clients that remain active in an upload session, + following the same code path as WebSocket clients. + """ # Import here to avoid circular imports - from advanced_omi_backend.main import cleanup_client_state, create_client_state + from advanced_omi_backend.main import create_client_state, cleanup_client_state + import uuid audio_logger.info( f"๐Ÿš€ process_files_with_content called for job {job_id} with {len(file_data)} files" @@ -536,8 +541,13 @@ async def process_files_with_content( # Update job status to processing await job_tracker.update_job_status(job_id, JobStatus.PROCESSING) + # Process files one by one + processed_files = [] + for file_index, (filename, content) in enumerate(file_data): - client_id = None + # Generate client ID for this file + file_device_name = f"{device_name}-{file_index + 1:03d}" + client_id = generate_client_id(user, file_device_name) client_state = None try: @@ -577,18 +587,22 @@ async def process_files_with_content( ) continue - # Generate unique client ID for each file + # Use pre-generated client ID from upload session file_device_name = f"{device_name}-{file_index + 1:03d}" - client_id = generate_client_id(user, file_device_name) # Update job tracker with client ID await job_tracker.update_file_status( job_id, filename, FileStatus.PROCESSING, 
client_id=client_id ) - # Create client state + # Create persistent client state (will be tracked by ProcessorManager) client_state = await create_client_state(client_id, user, file_device_name) + + audio_logger.info( + f"๐Ÿ‘ค [Job {job_id}] Created persistent client {client_id} for file {filename}" + ) + # Process WAV file with wave.open(io.BytesIO(content), "rb") as wav_file: sample_rate = wav_file.getframerate() @@ -732,21 +746,23 @@ async def process_files_with_content( job_id, filename, FileStatus.FAILED, error_message=error_msg ) finally: - # Always clean up client state to prevent accumulation + # Clean up client state immediately after upload completes (like WebSocket disconnect) + # ProcessorManager will continue tracking processing independently if client_id and client_state: try: await cleanup_client_state(client_id) - audio_logger.info( - f"๐Ÿงน [Job {job_id}] Cleaned up client state for {client_id}" - ) + audio_logger.info(f"๐Ÿงน Cleaned up client state for {client_id}") except Exception as cleanup_error: audio_logger.error( - f"โŒ [Job {job_id}] Error cleaning up client state for {client_id}: {cleanup_error}" + f"โŒ Error cleaning up client state for {client_id}: {cleanup_error}" ) # Mark job as completed await job_tracker.update_job_status(job_id, JobStatus.COMPLETED) - audio_logger.info(f"๐ŸŽ‰ [Job {job_id}] All files processed") + + audio_logger.info( + f"๐ŸŽ‰ [Job {job_id}] All files processed successfully." 
+ ) except Exception as e: error_msg = f"Job processing failed: {str(e)}" @@ -754,6 +770,7 @@ async def process_files_with_content( await job_tracker.update_job_status(job_id, JobStatus.FAILED, error_msg) + # Configuration functions moved to config.py to avoid circular imports @@ -1282,3 +1299,6 @@ async def get_client_processing_detail(client_id: str): return JSONResponse( status_code=500, content={"error": f"Failed to get client detail: {str(e)}"} ) + + + diff --git a/backends/advanced/src/advanced_omi_backend/main.py b/backends/advanced/src/advanced_omi_backend/main.py index 1eaafabe..f463f29d 100644 --- a/backends/advanced/src/advanced_omi_backend/main.py +++ b/backends/advanced/src/advanced_omi_backend/main.py @@ -273,6 +273,14 @@ async def cleanup_client_state(client_id: str): removed = await client_manager.remove_client_with_cleanup(client_id) if removed: + # Clean up processor manager task tracking + try: + processor_manager = get_processor_manager() + processor_manager.cleanup_processing_tasks(client_id) + logger.debug(f"Cleaned up processor tasks for client {client_id}") + except Exception as processor_cleanup_error: + logger.error(f"Error cleaning up processor tasks for {client_id}: {processor_cleanup_error}") + # Clean up any orphaned transcript events for this client coordinator = get_transcript_coordinator() coordinator.cleanup_transcript_events_for_client(client_id) @@ -320,6 +328,7 @@ async def lifespan(app: FastAPI): processor_manager = init_processor_manager(CHUNK_DIR, ac_repository) await processor_manager.start() + logger.info("App ready") try: yield @@ -331,6 +340,7 @@ async def lifespan(app: FastAPI): for client_id in client_manager.get_all_client_ids(): await cleanup_client_state(client_id) + # Shutdown processor manager processor_manager = get_processor_manager() await processor_manager.shutdown() diff --git a/backends/advanced/src/advanced_omi_backend/memory/memory_service.py 
b/backends/advanced/src/advanced_omi_backend/memory/memory_service.py index dc5bc21e..9518d6e1 100644 --- a/backends/advanced/src/advanced_omi_backend/memory/memory_service.py +++ b/backends/advanced/src/advanced_omi_backend/memory/memory_service.py @@ -176,11 +176,13 @@ async def add_memory( created_ids: List[str] = [] # If allow_update, try LLM-driven action proposal + update_processing_successful = False if allow_update and fact_memories_text: memory_logger.info(f"๐Ÿ” Allowing update for {source_id}") created_ids = await self._process_memory_updates( fact_memories_text, embeddings, user_id, client_id, source_id, user_email ) + update_processing_successful = True else: memory_logger.info(f"๐Ÿ” Not allowing update for {source_id}") # Add all extracted memories normally @@ -197,9 +199,15 @@ async def add_memory( if created_ids and db_helper: await self._update_database_relationships(db_helper, source_id, created_ids) + # Success conditions: + # 1. Normal path: created_ids > 0 (memories were added/updated) + # 2. 
Update path: LLM successfully processed actions (even if all NONE) if created_ids: memory_logger.info(f"โœ… Upserted {len(created_ids)} memories for {source_id}") return True, created_ids + elif update_processing_successful: + memory_logger.info(f"โœ… Memory update processing completed for {source_id} - LLM decided no changes needed") + return True, [] error_msg = f"โŒ No memories created for {source_id}: memory_entries={len(memory_entries) if memory_entries else 0}, allow_update={allow_update}" memory_logger.error(error_msg) diff --git a/backends/advanced/src/advanced_omi_backend/processors.py b/backends/advanced/src/advanced_omi_backend/processors.py index 4a7343d3..67ea82a9 100644 --- a/backends/advanced/src/advanced_omi_backend/processors.py +++ b/backends/advanced/src/advanced_omi_backend/processors.py @@ -429,10 +429,23 @@ def get_processing_status(self, client_id: str) -> dict[str, Any]: # Check if all stages are complete all_complete = all(stage_info["completed"] for stage_info in stages.values()) + # Get user_id for the client from ClientManager + from advanced_omi_backend.client_manager import get_client_owner + user_id = get_client_owner(client_id) or "Unknown" + + # Determine client type (simple heuristic based on client_id pattern) + # Upload clients have pattern like: "abc123-upload-001", "abc123-upload-001-2", etc. 
+ # They contain "-upload-" in their client_id + # Reprocessing clients have pattern like: "reprocess-{conversation_id}" and should be treated like upload clients + import re + client_type = "upload" if ("-upload-" in client_id or client_id.startswith("reprocess-")) else "websocket" + return { "status": "complete" if all_complete else "processing", "stages": stages, "client_id": client_id, + "user_id": user_id, + "client_type": client_type, } def cleanup_processing_tasks(self, client_id: str): @@ -445,6 +458,167 @@ def cleanup_processing_tasks(self, client_id: str): del self.processing_state[client_id] logger.debug(f"Cleaned up processing state for client {client_id}") + def _is_stale(self, client_id: str, max_idle_minutes: int = 30) -> bool: + """Check if a processing entry is stale (no activity for specified time). + + Args: + client_id: Client ID to check + max_idle_minutes: Maximum idle time in minutes before considering stale + + Returns: + True if the entry is stale and should be cleaned up + """ + import time + + max_idle_seconds = max_idle_minutes * 60 + current_time = time.time() + + # Check processing_state timestamps + if client_id in self.processing_state: + client_state = self.processing_state[client_id] + # Find the most recent timestamp across all stages + latest_timestamp = 0 + for stage_info in client_state.values(): + if isinstance(stage_info, dict) and "timestamp" in stage_info: + latest_timestamp = max(latest_timestamp, stage_info["timestamp"]) + + if latest_timestamp > 0: + idle_time = current_time - latest_timestamp + return idle_time > max_idle_seconds + + # If no processing_state or no valid timestamps, consider it stale + return True + + def _cleanup_completed_entries(self): + """Clean up completed and stale processing entries independently of client lifecycle. + + This method is called from existing processor timeout handlers to maintain + clean processing state without affecting active client sessions. 
+ """ + import time + + clients_to_remove = [] + current_time = time.time() + + for client_id in list(self.processing_state.keys()): + try: + status = self.get_processing_status(client_id) + + # Clean up if processing is complete OR if upload client is done (even with failed stages) + client_type = status.get("client_type", "websocket") + + if status.get("status") == "complete": + if client_type == "upload": + # Upload clients: Clean up immediately when processing completes + clients_to_remove.append((client_id, "completed_upload")) + logger.info(f"Marking completed upload client for immediate cleanup: {client_id}") + + # Also trigger client state cleanup for upload clients + try: + from advanced_omi_backend.main import cleanup_client_state + import asyncio + + # Schedule client cleanup + asyncio.create_task(self._cleanup_upload_client_state(client_id)) + except Exception as cleanup_error: + logger.error(f"Error scheduling upload client cleanup for {client_id}: {cleanup_error}") + else: + # WebSocket clients: Wait for grace period before cleanup + completion_grace_period = 300 # 5 minutes + + # Check if all stages have been complete for grace period + all_stages_old_enough = True + for stage_info in status.get("stages", {}).values(): + if "timestamp" in stage_info: + stage_age = current_time - stage_info["timestamp"] + if stage_age < completion_grace_period: + all_stages_old_enough = False + break + + if all_stages_old_enough: + clients_to_remove.append((client_id, "completed_websocket")) + logger.info(f"Marking completed WebSocket client for cleanup: {client_id}") + + elif client_type == "upload" and status.get("status") == "processing": + # Upload clients: Also clean up if they're done processing (even with failed stages) + # Check if all stages are either completed or have failed (i.e., no longer actively processing) + stages = status.get("stages", {}) + all_stages_done = True + + for stage_name, stage_info in stages.items(): + if not stage_info.get("completed", 
False) and stage_info.get("status") not in ["failed", "completed"]: + all_stages_done = False + break + + if all_stages_done: + clients_to_remove.append((client_id, "finished_upload")) + logger.info(f"Marking finished upload client for cleanup: {client_id} (some stages may have failed)") + + # Also trigger client state cleanup for upload clients + try: + from advanced_omi_backend.main import cleanup_client_state + import asyncio + + # Schedule client cleanup + asyncio.create_task(self._cleanup_upload_client_state(client_id)) + except Exception as cleanup_error: + logger.error(f"Error scheduling upload client cleanup for {client_id}: {cleanup_error}") + + # Clean up if stale (no activity for 30+ minutes) + elif self._is_stale(client_id, max_idle_minutes=30): + clients_to_remove.append((client_id, "stale")) + logger.info(f"Marking stale processing entry for cleanup: {client_id}") + + except Exception as e: + logger.error(f"Error checking processing status for {client_id}: {e}") + # If we can't check status, consider it for cleanup + clients_to_remove.append((client_id, "error")) + + # Remove the identified entries + for client_id, reason in clients_to_remove: + try: + self._remove_processing_entry(client_id, reason) + except Exception as e: + logger.error(f"Error removing processing entry for {client_id}: {e}") + + async def _cleanup_upload_client_state(self, client_id: str): + """Clean up client state for completed upload clients. + + This method handles the client state cleanup that was previously done + in the background task's finally block, but now happens when processing completes. 
+ """ + try: + from advanced_omi_backend.main import cleanup_client_state + + logger.info(f"๐Ÿงน Starting upload client state cleanup for {client_id}") + await cleanup_client_state(client_id) + logger.info(f"โœ… Successfully cleaned up upload client state for {client_id}") + + except Exception as e: + logger.error(f"โŒ Error cleaning up upload client state for {client_id}: {e}", exc_info=True) + + def _remove_processing_entry(self, client_id: str, reason: str = "cleanup"): + """Remove processing state and task tracking for a client. + + Args: + client_id: Client ID to remove + reason: Reason for removal (for logging) + """ + removed_items = [] + + if client_id in self.processing_state: + del self.processing_state[client_id] + removed_items.append("processing_state") + + if client_id in self.processing_tasks: + del self.processing_tasks[client_id] + removed_items.append("processing_tasks") + + if removed_items: + logger.info(f"๐Ÿงน Cleaned up processing entry for {client_id} ({reason}): {', '.join(removed_items)}") + else: + logger.debug(f"No processing entry found to clean up for {client_id} ({reason})") + def get_all_processing_status(self) -> dict[str, Any]: """Get processing status for all clients.""" # Get all client IDs from both tracking types @@ -815,7 +989,7 @@ async def _audio_processor(self): ) except asyncio.TimeoutError: - # Periodic health check + # Periodic health check and cleanup active_clients = len(self.active_file_sinks) queue_size = self.audio_queue.qsize() if queue_size > 0 or active_clients > 0: @@ -824,6 +998,12 @@ async def _audio_processor(self): f"{queue_size} items in queue" ) + # Perform cleanup of completed/stale processing entries + try: + self._cleanup_completed_entries() + except Exception as cleanup_error: + audio_logger.error(f"Error during processing entry cleanup: {cleanup_error}") + except Exception as e: audio_logger.error(f"Fatal error in audio processor: {e}", exc_info=True) finally: diff --git 
a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py index 494db6ce..21534a6f 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py @@ -189,3 +189,5 @@ async def get_client_processing_detail_route( ): """Get detailed processing information for specific client. Admin only.""" return await system_controller.get_client_processing_detail(client_id) + + diff --git a/backends/advanced/webui/src/pages/Processes.tsx b/backends/advanced/webui/src/pages/Processes.tsx index 0eaf050f..67a9733c 100644 --- a/backends/advanced/webui/src/pages/Processes.tsx +++ b/backends/advanced/webui/src/pages/Processes.tsx @@ -1,5 +1,5 @@ import { useState, useEffect } from 'react' -import { Activity, RefreshCw, Users, Clock, BarChart3 } from 'lucide-react' +import { Activity, RefreshCw } from 'lucide-react' import { systemApi } from '../services/api' import { useAuth } from '../contexts/AuthContext' import ProcessPipelineView from '../components/processes/ProcessPipelineView' @@ -45,26 +45,6 @@ interface ProcessingHistoryItem { error?: string } -interface ClientProcessingDetail { - client_id: string - client_info: { - user_id: string - user_email: string - current_audio_uuid?: string - conversation_start_time?: string - sample_rate?: number - } - processing_status: any - active_tasks: Array<{ - task_id: string - task_name: string - task_type: string - created_at: string - completed_at?: string - error?: string - cancelled: boolean - }> -} export default function Processes() { const [overviewData, setOverviewData] = useState(null) diff --git a/backends/advanced/webui/src/pages/System.tsx b/backends/advanced/webui/src/pages/System.tsx index c1283660..9c1b34eb 100644 --- a/backends/advanced/webui/src/pages/System.tsx +++ b/backends/advanced/webui/src/pages/System.tsx @@ -144,9 
+144,6 @@ export default function System() { return displayNames[service] || service.replace('_', ' ').toUpperCase() } - const formatDate = (dateString: string) => { - return new Date(dateString).toLocaleString() - } if (!isAdmin) { return ( diff --git a/backends/advanced/webui/src/pages/Upload.tsx b/backends/advanced/webui/src/pages/Upload.tsx index 04e7d24c..b77005b4 100644 --- a/backends/advanced/webui/src/pages/Upload.tsx +++ b/backends/advanced/webui/src/pages/Upload.tsx @@ -1,6 +1,6 @@ -import React, { useState, useCallback } from 'react' +import React, { useState, useCallback, useEffect } from 'react' import { Upload as UploadIcon, File, X, CheckCircle, AlertCircle, RefreshCw } from 'lucide-react' -import { uploadApi } from '../services/api' +import { uploadApi, systemApi } from '../services/api' import { useAuth } from '../contexts/AuthContext' interface UploadFile { @@ -10,11 +10,63 @@ interface UploadFile { error?: string } +// Legacy JobStatus interface - kept for backward compatibility +interface JobStatus { + job_id: string + status: 'pending' | 'processing' | 'completed' | 'failed' + total_files: number + processed_files: number + current_file?: string + progress_percent: number + files?: Array<{ + filename: string + client_id: string + status: 'pending' | 'processing' | 'completed' | 'failed' + transcription_status?: string + memory_status?: string + error_message?: string + }> +} + +// New unified processing interfaces +interface ProcessingTask { + client_id: string + user_id: string + status: 'processing' | 'complete' + stages: Record +} + +// UploadSessionData interface removed - replaced by unified processor tasks polling + +interface UploadSession { + job_id: string + file_names: string[] + started_at: number + upload_completed: boolean + total_files: number +} + export default function Upload() { const [files, setFiles] = useState([]) - const [isUploading, setIsUploading] = useState(false) const [dragActive, setDragActive] = useState(false) + + 
// Three-phase state management + const [uploadPhase, setUploadPhase] = useState<'idle' | 'uploading' | 'completed'>('idle') const [uploadProgress, setUploadProgress] = useState(0) + const [processingPhase, setProcessingPhase] = useState<'idle' | 'starting' | 'active' | 'completed'>('idle') + const [jobStatus, setJobStatus] = useState(null) + const [processingTasks, setProcessingTasks] = useState([]) + + // Polling configuration + const [autoRefresh, setAutoRefresh] = useState(true) + const [refreshInterval, setRefreshInterval] = useState(2000) // 2s default for upload page + const [isPolling, setIsPolling] = useState(false) const { isAdmin } = useAuth() @@ -61,10 +113,146 @@ export default function Upload() { handleFileSelect(e.dataTransfer.files) }, []) + // localStorage persistence + const saveSession = (session: UploadSession) => { + localStorage.setItem('upload_session', JSON.stringify(session)) + } + + const getStoredSession = (): UploadSession | null => { + const saved = localStorage.getItem('upload_session') + return saved ? 
JSON.parse(saved) : null + } + + const clearStoredSession = () => { + localStorage.removeItem('upload_session') + } + + // Resume session on page load + useEffect(() => { + const session = getStoredSession() + if (session) { + setProcessingPhase('active') + setIsPolling(true) + // Use unified polling without session dependency + pollProcessingStatus() + } + }, []) + + // Polling effect + useEffect(() => { + if (!autoRefresh || !isPolling) return + + const interval = setInterval(() => { + pollProcessingStatus() + }, refreshInterval) + + return () => clearInterval(interval) + }, [autoRefresh, refreshInterval, isPolling]) + + // New unified polling approach - polls processor tasks directly without session dependency + const pollProcessingStatus = async () => { + try { + // Get all processor tasks + const tasksResponse = await systemApi.getProcessorTasks() + const allTasks = tasksResponse.data + + // Filter for upload clients (identified by client_id pattern ending with 3-digit numbers like "-001", "-002") + const uploadTasks: ProcessingTask[] = Object.entries(allTasks) + .filter(([clientId, taskData]) => { + // Upload clients have pattern like: "abc123-upload-001", "abc123-upload-002" + return /.*-upload-\d{3}$/.test(clientId) + }) + .map(([clientId, taskData]: [string, any]) => ({ + client_id: clientId, + user_id: taskData?.user_id || 'Unknown', + status: taskData?.status || 'processing', + stages: taskData?.stages || {} + })) + .filter(task => Object.keys(task.stages).length > 0) // Only show clients with active processing + + setProcessingTasks(uploadTasks) + + // Check if all clients are complete OR no upload tasks exist (meaning processing finished) + const allComplete = uploadTasks.length > 0 && uploadTasks.every(task => task.status === 'complete') + const noActiveTasks = uploadTasks.length === 0 && processingPhase === 'active' + + if (allComplete || noActiveTasks) { + setIsPolling(false) + setProcessingPhase('completed') + clearStoredSession() + + 
setFiles(prevFiles => + prevFiles.map(f => ({ + ...f, + status: 'success' + })) + ) + } else if (uploadTasks.some(task => Object.values(task.stages).some(stage => stage.error))) { + // Check for any errors in processing stages + const hasErrors = uploadTasks.some(task => + Object.values(task.stages).some(stage => stage.error) + ) + + if (hasErrors) { + setFiles(prevFiles => + prevFiles.map(f => ({ + ...f, + status: 'error', + error: 'Processing failed' + })) + ) + } + } + } catch (error) { + console.error('Failed to poll processing status:', error) + } + } + + // Legacy job polling for backward compatibility + const pollJobStatus = async (jobId: string) => { + try { + // Use new unified polling (no session dependency) + await pollProcessingStatus() + + // Also get legacy job status for progress display (if available) + try { + const response = await uploadApi.getJobStatus(jobId) + const status: JobStatus = response.data + setJobStatus(status) + } catch (jobError) { + console.log('Legacy job status not available, using unified polling only') + } + } catch (error) { + console.error('Failed to poll unified processing status:', error) + // Fallback to legacy job polling + try { + const response = await uploadApi.getJobStatus(jobId) + const status: JobStatus = response.data + setJobStatus(status) + + if (status.status === 'completed' || status.status === 'failed') { + setIsPolling(false) + setProcessingPhase('completed') + clearStoredSession() + + setFiles(prevFiles => + prevFiles.map(f => ({ + ...f, + status: status.status === 'completed' ? 
'success' : 'error' + })) + ) + } + } catch (fallbackError) { + console.error('All polling methods failed:', fallbackError) + } + } + } + const uploadFiles = async () => { if (files.length === 0) return - setIsUploading(true) + // Phase 1: File Upload + setUploadPhase('uploading') setUploadProgress(0) try { @@ -74,38 +262,66 @@ export default function Upload() { }) // Update all files to uploading status - setFiles(prevFiles => + setFiles(prevFiles => prevFiles.map(f => ({ ...f, status: 'uploading' as const })) ) - await uploadApi.uploadAudioFiles(formData, (progress) => { + // Phase 1: Upload files and get job ID + const response = await uploadApi.uploadAudioFilesAsync(formData, (progress) => { setUploadProgress(progress) }) - - // Mark all files as successful - setFiles(prevFiles => - prevFiles.map(f => ({ ...f, status: 'success' as const })) - ) + + // Phase 2: Job Creation + setUploadPhase('completed') + setProcessingPhase('starting') + + const jobData = response.data + const jobId = jobData.job_id || jobData.jobs?.[0]?.job_id + + if (!jobId) { + throw new Error('No job ID received from server') + } + + // Save session for disconnection handling + const session: UploadSession = { + job_id: jobId, + file_names: files.map(f => f.file.name), + started_at: Date.now(), + upload_completed: true, + total_files: files.length + } + saveSession(session) + + // Phase 3: Start polling for processing status + setProcessingPhase('active') + setIsPolling(true) + pollJobStatus(jobId) } catch (error: any) { console.error('Upload failed:', error) - + + setUploadPhase('idle') + setProcessingPhase('idle') + // Mark all files as failed - setFiles(prevFiles => - prevFiles.map(f => ({ - ...f, - status: 'error' as const, - error: error.message || 'Upload failed' + setFiles(prevFiles => + prevFiles.map(f => ({ + ...f, + status: 'error' as const, + error: error.message || 'Upload failed' })) ) - } finally { - setIsUploading(false) - setUploadProgress(100) } } const clearCompleted = () 
=> { setFiles(files.filter(f => f.status === 'pending' || f.status === 'uploading')) + if (processingPhase === 'completed') { + setProcessingPhase('idle') + setUploadPhase('idle') + setJobStatus(null) + clearStoredSession() + } } const formatFileSize = (bytes: number) => { @@ -205,10 +421,13 @@ export default function Upload() { @@ -261,12 +480,12 @@ export default function Upload() { )} - {/* Upload Progress */} - {isUploading && ( + {/* Phase 1: Upload Progress */} + {uploadPhase === 'uploading' && (
- Processing audio files... + Uploading files... ({files.length} files) {uploadProgress}% @@ -278,9 +497,124 @@ export default function Upload() { style={{ width: `${uploadProgress}%` }} />
-

- Note: Processing may take up to 5 minutes depending on file size and quantity. -

+
+ )} + + {/* Phase 2: Job Creation */} + {processingPhase === 'starting' && ( +
+
+ + Files uploaded. Starting processing jobs... + + +
+
+ )} + + {/* Phase 3: Processing Status with Configurable Refresh */} + {processingPhase === 'active' && jobStatus && ( +
+ {/* Refresh Controls */} +
+
+ + + +
+ + +
+ + {/* Processing Status */} +
+
+ + Processing file {jobStatus.processed_files + 1}/{jobStatus.total_files} + {jobStatus.current_file && `: ${jobStatus.current_file}`} + + + {Math.round(jobStatus.progress_percent)}% + +
+ +
+
+
+ +

+ Processing may take up to 3x audio duration + 60s. Status updates every {refreshInterval/1000}s. +

+
+ + {/* Per-File Status */} + {jobStatus.files && jobStatus.files.length > 0 && ( +
+

File Processing Status

+
+ {jobStatus.files.map((file, index) => ( +
+ + {file.filename} + +
+ + {file.status.charAt(0).toUpperCase() + file.status.slice(1)} + + {file.status === 'processing' && ( + + )} +
+
+ ))} +
+
+ )} +
+ )} + + {/* Completion Status */} + {processingPhase === 'completed' && ( +
+
+ + + All files processed successfully! Check the Conversations tab to see results. + +
)} @@ -290,10 +624,12 @@ export default function Upload() { ๐Ÿ“ Upload Instructions
    -
  • • Audio files will be processed sequentially for transcription and memory extraction
  • -
  • • Processing time varies based on audio length (roughly 3x the audio duration + 60s)
  • -
  • • Large files or multiple files may cause timeout errors - this is normal
  • -
  • • Check the Conversations tab to see processed results
  • +
  • • Phase 1: Files upload quickly to server (progress bar shows transfer)
  • +
  • • Phase 2: Processing jobs created (immediate)
  • +
  • • Phase 3: Audio processing (transcription + memory extraction, ~3x audio duration)
  • +
  • • You can safely navigate away - processing continues in background
  • +
  • • Refresh rate is configurable (0.5s to 10s) during processing
  • +
  • • Check Conversations tab for final results
  • • Supported formats: WAV, MP3, M4A, FLAC
diff --git a/backends/advanced/webui/src/services/api.ts b/backends/advanced/webui/src/services/api.ts index 9da281e6..32dec703 100644 --- a/backends/advanced/webui/src/services/api.ts +++ b/backends/advanced/webui/src/services/api.ts @@ -141,7 +141,7 @@ export const systemApi = { } export const uploadApi = { - uploadAudioFiles: (files: FormData, onProgress?: (progress: number) => void) => + uploadAudioFiles: (files: FormData, onProgress?: (progress: number) => void) => api.post('/api/process-audio-files', files, { headers: { 'Content-Type': 'multipart/form-data' }, timeout: 300000, // 5 minutes @@ -152,6 +152,27 @@ export const uploadApi = { } } }), + + // Async upload using existing infrastructure - returns job IDs for monitoring + uploadAudioFilesAsync: (files: FormData, onUploadProgress?: (progress: number) => void) => + api.post('/api/process-audio-files-async', files, { + headers: { 'Content-Type': 'multipart/form-data' }, + timeout: 300000, // 5 minutes for upload phase + onUploadProgress: (progressEvent) => { + if (onUploadProgress && progressEvent.total) { + const progress = Math.round((progressEvent.loaded * 100) / progressEvent.total) + onUploadProgress(progress) + } + } + }), + + // Get job status for a specific job + getJobStatus: (jobId: string) => + api.get(`/api/process-audio-files/jobs/${jobId}`), + + // Get status for multiple jobs + getJobStatuses: (jobIds: string[]) => + Promise.all(jobIds.map(jobId => uploadApi.getJobStatus(jobId))) } export const chatApi = { @@ -205,4 +226,6 @@ export const speakerApi = { // Check speaker service status (admin only) getSpeakerServiceStatus: () => api.get('/api/speaker-service-status'), -} \ No newline at end of file +} + +// Upload session API removed - functionality replaced by unified processor tasks polling diff --git a/extras/speaker-recognition/sortformer.py b/extras/speaker-recognition/sortformer.py new file mode 100644 index 00000000..d1990fd1 --- /dev/null +++ b/extras/speaker-recognition/sortformer.py 
#!/usr/bin/env python3
"""
Test script for NVIDIA SortFormer diarization model with speaker enrollment.
Tests on conversation and enrollment audio files, then maps diarized tracks to enrolled speakers.
"""
import os
import sys
import tempfile
import wave
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import nemo.collections.asr as nemo_asr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from nemo.collections.asr.models import SortformerEncLabelModel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TARGET_SR = 16000


def get_audio_duration(file_path):
    """Return the duration of a WAV file in seconds.

    This is a display-only helper: any I/O or format problem yields 0.0
    instead of an exception so that the surrounding report keeps printing.
    """
    try:
        with wave.open(file_path, "rb") as wav_file:
            return wav_file.getnframes() / float(wav_file.getframerate())
    except (wave.Error, EOFError, OSError, ZeroDivisionError):
        # Not a readable WAV (or a header with frame rate 0).
        return 0.0


def load_audio_16k_mono(path: str) -> Tuple[torch.Tensor, int]:
    """Load an audio file, down-mix to mono and resample to TARGET_SR.

    Returns:
        A 1-D waveform tensor and the sample rate (always TARGET_SR).
    """
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)  # convert to mono
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
    return wav.squeeze(0), TARGET_SR


def write_temp_wav(path: str, wav: torch.Tensor, sr: int = TARGET_SR) -> None:
    """Write a waveform tensor to ``path`` so the embedding model can read it."""
    sf.write(path, wav.cpu().numpy(), sr)


def get_embedding_from_file(speaker_model, file_path: str) -> Optional[torch.Tensor]:
    """Extract an L2-normalized speaker embedding from an audio file.

    Returns:
        A 1-D CPU float tensor, or None if extraction failed (the error is
        printed; enrollment/mapping simply skips that file).
    """
    try:
        with torch.no_grad():
            emb = speaker_model.get_embedding(file_path)

        # get_embedding may hand back a tuple/list or a numpy array.
        if isinstance(emb, (list, tuple)):
            emb = emb[0]
        if isinstance(emb, np.ndarray):
            emb = torch.from_numpy(emb)

        emb = emb.float().squeeze().cpu()
        # Epsilon guards against division by zero for an all-zero embedding.
        return emb / (emb.norm(p=2) + 1e-9)
    except Exception as e:
        print(f" ERROR extracting embedding from {file_path}: {e}")
        return None


def _mean_unit_embedding(embeddings: List[torch.Tensor]) -> torch.Tensor:
    """Average a non-empty list of embeddings and L2-normalize the result."""
    centroid = torch.stack(embeddings, dim=0).mean(dim=0)
    return centroid / (centroid.norm(p=2) + 1e-9)


def _cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity between two 1-D embeddings."""
    return float(torch.dot(a, b) / ((a.norm(p=2) + 1e-9) * (b.norm(p=2) + 1e-9)))


def _parse_segment(seg) -> Tuple[float, float, int]:
    """Coerce one diarization segment to ``(start_sec, end_sec, speaker_index)``.

    Accepts an indexable ``(start, end, speaker)`` triple, or a
    whitespace-separated string such as ``"0.08 3.92 speaker_1"``.

    Raises:
        ValueError: if the segment has fewer than three fields or a field
            is not numeric.
    """
    fields = seg.split() if isinstance(seg, str) else seg
    if len(fields) < 3:
        raise ValueError(f"Malformed diarization segment: {seg!r}")
    speaker = fields[2]
    if isinstance(speaker, str):
        # "speaker_1" -> "1"; a plain "1" is left untouched.
        speaker = speaker.rsplit("_", 1)[-1]
    return float(fields[0]), float(fields[1]), int(speaker)


def _is_segment(item) -> bool:
    """Return True if ``item`` can be parsed as a single diarization segment."""
    try:
        _parse_segment(item)
    except (TypeError, ValueError, IndexError):
        return False
    return True


def _normalize_segments(raw) -> List[Tuple[float, float, int]]:
    """Flatten diarizer output into a list of ``(start, end, speaker)`` tuples.

    NOTE(review): NeMo's Sortformer ``diarize`` appears to return one list per
    input file containing "start end speaker_N" strings - confirm against the
    installed NeMo version. Both that shape and a flat list of triples are
    accepted here.
    """
    items = list(raw)
    # A first element that is a list but not itself a segment is taken to be
    # the per-file wrapper; this script only ever diarizes one file.
    if items and isinstance(items[0], (list, tuple)) and not _is_segment(items[0]):
        items = list(items[0])
    return [_parse_segment(item) for item in items]


def create_speaker_enrollment(speaker_model, enrollment_files: Dict[str, List[str]]) -> Dict[str, torch.Tensor]:
    """Create speaker enrollment centroids from multiple audio files per speaker.

    Args:
        speaker_model: model exposing ``get_embedding(path)``.
        enrollment_files: speaker name -> list of enrollment audio paths.

    Returns:
        Speaker name -> normalized centroid. Speakers for which no file
        produced an embedding are omitted.
    """
    enrollment = {}

    print("\n" + "="*60)
    print("SPEAKER ENROLLMENT")
    print("="*60)

    for speaker_name, file_list in enrollment_files.items():
        print(f"\nEnrolling {speaker_name}...")
        embeddings = []

        for file_path in file_list:
            if not os.path.exists(file_path):
                print(f" WARNING: {file_path} not found")
                continue

            duration = get_audio_duration(file_path)
            print(f" Processing {os.path.basename(file_path)} ({duration:.1f}s)...")

            emb = get_embedding_from_file(speaker_model, file_path)
            if emb is not None:
                embeddings.append(emb)
                print(f" ✓ Embedding extracted (shape: {emb.shape})")

        if embeddings:
            centroid = _mean_unit_embedding(embeddings)
            enrollment[speaker_name] = centroid
            print(f" ✓ {speaker_name} enrolled with {len(embeddings)} samples")
            print(f" Centroid shape: {centroid.shape}")
        else:
            print(f" ✗ Failed to enroll {speaker_name} - no valid embeddings")

    return enrollment


def extract_segments_embeddings(speaker_model, audio_file: str, segments: List) -> Dict[int, torch.Tensor]:
    """Extract one averaged embedding per diarized speaker track.

    Args:
        speaker_model: model exposing ``get_embedding(path)``.
        audio_file: path of the diarized conversation audio.
        segments: diarization segments, each ``(start, end, speaker)`` or a
            "start end speaker_N" string.

    Returns:
        Speaker index -> normalized track embedding. Tracks with no usable
        segment are omitted.
    """
    print("\n" + "="*60)
    print("EXTRACTING TRACK EMBEDDINGS")
    print("="*60)

    full_wav, _ = load_audio_16k_mono(audio_file)
    total_samples = full_wav.numel()

    # Group segments by speaker.
    speaker_segments: Dict[int, List[Tuple[float, float]]] = {}
    for seg in segments:
        start, end, spk_idx = _parse_segment(seg)
        speaker_segments.setdefault(spk_idx, []).append((start, end))

    min_samples = TARGET_SR // 10  # segments under 0.1 s are skipped
    track_embeddings = {}

    # A private temp directory avoids clobbering files in the working
    # directory and is removed even if embedding extraction raises.
    with tempfile.TemporaryDirectory(prefix="tmp_segments_") as temp_dir:
        for spk_idx, seg_list in speaker_segments.items():
            print(f"\nProcessing Speaker Track {spk_idx}...")
            print(f" Found {len(seg_list)} segments")

            seg_embeddings = []

            for i, (start_sec, end_sec) in enumerate(seg_list):
                # Clamp to the waveform so a negative start cannot wrap
                # around and slice from the end of the tensor.
                start_samp = max(0, int(start_sec * TARGET_SR))
                end_samp = min(total_samples, int(end_sec * TARGET_SR))
                segment_wav = full_wav[start_samp:end_samp]

                if segment_wav.numel() < min_samples:
                    print(f" Skipping segment {i+1} (too short: {len(segment_wav)/TARGET_SR:.2f}s)")
                    continue

                temp_path = os.path.join(temp_dir, f"spk{spk_idx}_{i:03d}.wav")
                try:
                    write_temp_wav(temp_path, segment_wav, TARGET_SR)
                    emb = get_embedding_from_file(speaker_model, temp_path)
                finally:
                    # Remove eagerly to bound disk use on long recordings;
                    # the directory cleanup catches anything left behind.
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass

                if emb is not None:
                    seg_embeddings.append(emb)
                    print(f" ✓ Segment {i+1}: {start_sec:.2f}-{end_sec:.2f}s -> embedding extracted")

            if seg_embeddings:
                track_embeddings[spk_idx] = _mean_unit_embedding(seg_embeddings)
                print(f" ✓ Track {spk_idx}: {len(seg_embeddings)} segments -> final embedding")
            else:
                print(f" ✗ Track {spk_idx}: No valid embeddings extracted")

    return track_embeddings


def map_speakers_to_enrollment(track_embeddings: Dict[int, torch.Tensor],
                               enrollment: Dict[str, torch.Tensor],
                               similarity_threshold: float = 0.0) -> Dict[int, str]:
    """Map diarized speaker tracks to enrolled speaker identities.

    Each track is assigned the enrolled speaker with the highest cosine
    similarity, provided that similarity reaches ``similarity_threshold``;
    otherwise the track is labelled ``unknown_spk<idx>``.
    """
    print("\n" + "="*60)
    print("SPEAKER IDENTITY MAPPING")
    print("="*60)

    speaker_mapping = {}

    print(f"Similarity threshold: {similarity_threshold}")
    print(f"Available enrolled speakers: {list(enrollment.keys())}")

    for track_idx, track_emb in track_embeddings.items():
        print(f"\nMapping Track {track_idx}:")

        best_match = None
        best_similarity = -1.0

        for speaker_name, enrolled_emb in enrollment.items():
            similarity = _cosine_similarity(track_emb, enrolled_emb)
            print(f" vs {speaker_name}: {similarity:.4f}")

            if similarity > best_similarity:
                best_similarity = similarity
                best_match = speaker_name

        if best_match is not None and best_similarity >= similarity_threshold:
            speaker_mapping[track_idx] = best_match
            print(f" → Track {track_idx} mapped to: {best_match} (confidence: {best_similarity:.4f})")
        else:
            speaker_mapping[track_idx] = f"unknown_spk{track_idx}"
            print(f" → Track {track_idx} mapped to: unknown_spk{track_idx} (low confidence: {best_similarity:.4f})")

    return speaker_mapping


def generate_labeled_segments(segments: List, speaker_mapping: Dict[int, str]) -> List[Dict]:
    """Attach speaker labels to diarization segments.

    Returns:
        One dict per segment with keys ``start``, ``end``, ``speaker`` and
        ``duration``. Tracks missing from ``speaker_mapping`` are labelled
        ``spk<idx>``.
    """
    labeled_segments = []

    for seg in segments:
        start, end, spk_idx = _parse_segment(seg)
        labeled_segments.append({
            "start": start,
            "end": end,
            "speaker": speaker_mapping.get(spk_idx, f"spk{spk_idx}"),
            "duration": end - start,
        })

    return labeled_segments


def test_sortformer_with_enrollment():
    """Test SortFormer diarization with speaker enrollment and mapping."""
    # Audio file paths
    test_files = {
        "conversation": "tests/assets/conversation_evan_katelyn_2min.wav",
        "evan_enrollment": [
            "tests/assets/evan/evan_001.wav",
            "tests/assets/evan/evan_002.wav",
            "tests/assets/evan/evan_003.wav",
            "tests/assets/evan/evan_004.wav"
        ],
        "katelyn_enrollment": [
            "tests/assets/katelyn/katelyn_001.wav",
            "tests/assets/katelyn/katelyn_002.wav"
        ]
    }

    # Check if files exist
    print("Checking audio files...")
    for category, files in test_files.items():
        if isinstance(files, str):
            files = [files]
        for file_path in files:
            if not os.path.exists(file_path):
                print(f"WARNING: {file_path} not found")
            else:
                duration = get_audio_duration(file_path)
                print(f"✓ {file_path} (duration: {duration:.1f}s)")

    print(f"\nLoading models on {DEVICE}...")
    try:
        # Load diarization model
        diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2").to(DEVICE)
        diar_model.eval()
        print("✓ SortFormer diarization model loaded")

        # Load speaker verification model
        speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("nvidia/speakerverification_en_titanet_large").to(DEVICE)
        speaker_model.eval()
        print("✓ TitaNet speaker embedding model loaded")

    except Exception as e:
        print(f"ERROR loading models: {e}")
        return

    # Test basic diarization first
    conversation_file = test_files["conversation"]
    if not os.path.exists(conversation_file):
        print(f"ERROR: Conversation file not found: {conversation_file}")
        return

    print(f"\n{'='*60}")
    print(f"BASIC DIARIZATION TEST: {conversation_file}")
    print('='*60)

    try:
        raw_segments = diar_model.diarize(audio=conversation_file, batch_size=1)
        # Normalize once so every later stage sees (start, end, speaker) tuples.
        segments = _normalize_segments(raw_segments)
        print(f"\nFound {len(segments)} diarized segments:")
        for i, (start, end, spk) in enumerate(segments):
            print(f" {i+1:2d}: {start:6.2f}-{end:6.2f}s | Speaker {spk} | Duration: {end-start:.2f}s")

    except Exception as e:
        print(f"ERROR during diarization: {e}")
        return

    # Create speaker enrollment
    enrollment_files = {
        "Evan": test_files["evan_enrollment"],
        "Katelyn": test_files["katelyn_enrollment"]
    }

    enrollment = create_speaker_enrollment(speaker_model, enrollment_files)

    if not enrollment:
        print("ERROR: No speakers enrolled successfully")
        return

    # Extract embeddings for diarized tracks
    track_embeddings = extract_segments_embeddings(speaker_model, conversation_file, segments)

    if not track_embeddings:
        print("ERROR: No track embeddings extracted")
        return

    # Map speaker tracks to enrolled identities
    speaker_mapping = map_speakers_to_enrollment(track_embeddings, enrollment, similarity_threshold=0.3)

    # Generate final labeled segments
    labeled_segments = generate_labeled_segments(segments, speaker_mapping)

    # Display results
    print("\n" + "="*60)
    print("FINAL RESULTS WITH SPEAKER LABELS")
    print("="*60)

    print(f"\nLabeled segments ({len(labeled_segments)} total):")
    for i, seg in enumerate(labeled_segments):
        print(f" {i+1:2d}: {seg['start']:6.2f}-{seg['end']:6.2f}s | {seg['speaker']:12s} | {seg['duration']:.2f}s")

    # Summary by speaker
    print(f"\nSpeaker summary:")
    speaker_stats = {}
    for seg in labeled_segments:
        speaker = seg['speaker']
        speaker_stats.setdefault(speaker, {'count': 0, 'total_duration': 0.0})
        speaker_stats[speaker]['count'] += 1
        speaker_stats[speaker]['total_duration'] += seg['duration']

    for speaker, stats in speaker_stats.items():
        print(f" {speaker:12s}: {stats['count']:2d} segments, {stats['total_duration']:6.1f}s total")


if __name__ == "__main__":
    print("SortFormer Diarization + Speaker Enrollment Test Script")
    print("=" * 60)
    test_sortformer_with_enrollment()
    print("\nTest completed!")