diff --git a/.github/workflows/images.yaml b/.github/workflows/images.yaml index 9c0d7ba3..12e8099c 100644 --- a/.github/workflows/images.yaml +++ b/.github/workflows/images.yaml @@ -25,6 +25,18 @@ jobs: image_name: librechat-rag-api-dev-lite steps: + # Free up disk space + - name: Free Disk Space + uses: jlumbroso/free-disk-space@main + with: + tool-cache: true + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true + # Check out the repository - name: Checkout uses: actions/checkout@v4 @@ -57,3 +69,5 @@ jobs: ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest platforms: linux/amd64,linux/arm64 target: ${{ matrix.target }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.gitignore b/.gitignore index 38921790..4c709056 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,7 @@ venv/ *.pyc dev.yml SHOPIFY.md + +# docker override file +docker-compose.override.yaml +docker-compose.override.yml diff --git a/Dockerfile b/Dockerfile index b38d70e2..8204c40e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt # Download standard NLTK data, to prevent unstructured from downloading packages at runtime -RUN python -m nltk.downloader -d /app/nltk_data punkt_tab averaged_perceptron_tagger +RUN python -m nltk.downloader -d /app/nltk_data punkt_tab averaged_perceptron_tagger averaged_perceptron_tagger_eng ENV NLTK_DATA=/app/nltk_data # Disable Unstructured analytics diff --git a/Dockerfile.lite b/Dockerfile.lite index a40fd464..186360f1 100644 --- a/Dockerfile.lite +++ b/Dockerfile.lite @@ -15,7 +15,7 @@ COPY requirements.lite.txt . 
RUN pip install --no-cache-dir -r requirements.lite.txt # Download standard NLTK data, to prevent unstructured from downloading packages at runtime -RUN python -m nltk.downloader -d /app/nltk_data punkt_tab averaged_perceptron_tagger +RUN python -m nltk.downloader -d /app/nltk_data punkt_tab averaged_perceptron_tagger averaged_perceptron_tagger_eng ENV NLTK_DATA=/app/nltk_data # Disable Unstructured analytics diff --git a/PaychexDockerfile b/PaychexDockerfile index 9574f1dd..75648821 100644 --- a/PaychexDockerfile +++ b/PaychexDockerfile @@ -2,8 +2,6 @@ FROM python:3.12-slim AS main WORKDIR /app -WORKDIR /app - # Install pandoc and netcat RUN apt-get update \ && apt-get install -y --no-install-recommends \ @@ -17,7 +15,7 @@ COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt # Download standard NLTK data, to prevent unstructured from downloading packages at runtime -RUN python -m nltk.downloader -d /app/nltk_data punkt_tab averaged_perceptron_tagger +RUN python -m nltk.downloader -d /app/nltk_data punkt_tab averaged_perceptron_tagger averaged_perceptron_tagger_eng ENV NLTK_DATA=/app/nltk_data # Disable Unstructured analytics diff --git a/README.md b/README.md index d76aab84..64138fd5 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,33 @@ pip install -r requirements.txt uvicorn main:app ``` +### Clean Install (Local Development) + +To do a clean reinstall of all dependencies (e.g., after updating `requirements.txt`): + +```bash +# Remove existing virtual environment and recreate it +rm -rf venv +python3 -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +For the lite version (without sentence_transformers/huggingface): + +```bash +rm -rf venv +python3 -m venv venv +source venv/bin/activate +pip install -r requirements.lite.txt +``` + +For Docker, rebuild without cache: + +```bash +docker compose build --no-cache +``` + ### Environment Variables The following environment variables are required to run the 
application: @@ -59,6 +86,8 @@ The following environment variables are required to run the application: - `COLLECTION_NAME`: (Optional) The name of the collection in the vector store. Default value is "testcollection". - `CHUNK_SIZE`: (Optional) The size of the chunks for text processing. Default value is "1500". - `CHUNK_OVERLAP`: (Optional) The overlap between chunks during text processing. Default value is "100". +- `EMBEDDING_BATCH_SIZE`: (Optional) Number of document chunks to process per batch. Set to `0` (default) to disable batching. Recommended value is `750` for `text-embedding-3-small`. +- `EMBEDDING_MAX_QUEUE_SIZE`: (Optional) Maximum number of batches to buffer in memory during async processing. Default value is "3". - `RAG_UPLOAD_DIR`: (Optional) The directory where uploaded files are stored. Default value is "./uploads/". - `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False". - `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes @@ -71,7 +100,7 @@ The following environment variables are required to run the application: - azure: "text-embedding-3-small" (will be used as your Azure Deployment) - huggingface: "sentence-transformers/all-MiniLM-L6-v2" - huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses model defined on TEI service launch. - - vertexai: "text-embedding-004" + - vertexai: "gemini-embedding-001" - ollama: "nomic-embed-text" - bedrock: "amazon.titan-embed-text-v1" - google_genai: "gemini-embedding-001" @@ -90,11 +119,48 @@ The following environment variables are required to run the application: - `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings - `GOOGLE_API_KEY`, `GOOGLE_KEY`, `RAG_GOOGLE_API_KEY`: (Optional) Google API key for Google GenAI embeddings. 
Priority order: RAG_GOOGLE_API_KEY > GOOGLE_KEY > GOOGLE_API_KEY - `AWS_SESSION_TOKEN`: (Optional) may be needed for bedrock embeddings -- `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format, as accepted by [langchain](https://python.langchain.com/api_reference/google_vertexai/index.html) +- `GOOGLE_APPLICATION_CREDENTIALS`: (Optional) needed for Google VertexAI embeddings. This should be a path to a service account credential file in JSON format. +- `GOOGLE_CLOUD_PROJECT`: (Optional) Google Cloud project ID, needed for VertexAI embeddings. +- `GOOGLE_CLOUD_LOCATION`: (Optional) Google Cloud region for VertexAI embeddings. Defaults to `us-central1`. - `RAG_CHECK_EMBEDDING_CTX_LENGTH` (Optional) Default is true, disabling this will send raw input to the embedder, use this for custom embedding models. Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables. +### Embedding Batch Processing + +For large files, you can enable batched embedding processing to reduce memory consumption. This is particularly useful in memory-constrained environments like Kubernetes pods with memory limits. + +#### Configuration + +| Variable | Default | Description | +|----------|---------|-------------| +| `EMBEDDING_BATCH_SIZE` | `0` | Number of document chunks to process per batch. `0` disables batching (original behavior). | +| `EMBEDDING_MAX_QUEUE_SIZE` | `3` | Maximum number of batches to buffer in memory during async processing. 
| + +#### Recommended Settings + +For `text-embedding-3-small` model: +- `EMBEDDING_BATCH_SIZE=750` - Good balance of throughput and memory + +For memory-constrained environments (< 2GB RAM): +- `EMBEDDING_BATCH_SIZE=100-250` + +For high-throughput environments: +- `EMBEDDING_BATCH_SIZE=1000-2000` +- `EMBEDDING_MAX_QUEUE_SIZE=5` + +#### Behavior + +When `EMBEDDING_BATCH_SIZE > 0`: +- Documents are processed in batches of the specified size +- Each batch is embedded and inserted before the next batch starts +- On failure, successfully inserted documents are rolled back +- Memory usage is bounded by `EMBEDDING_BATCH_SIZE * EMBEDDING_MAX_QUEUE_SIZE` + +When `EMBEDDING_BATCH_SIZE = 0` (default): +- All documents are processed at once (original behavior) +- Better for small files or memory-rich environments + ### Use Atlas MongoDB as Vector Database Instead of using the default pgvector, we could use [Atlas MongoDB](https://www.mongodb.com/products/platform/atlas-vector-search) as the vector database. To do so, set the following environment variables @@ -127,6 +193,16 @@ The `ATLAS_MONGO_DB_URI` could be the same or different from what is used by Lib Follow one of the [four documented methods](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure) to create the vector index. +#### Create a `file_id` Index (recommended) + +We recommend creating a standard MongoDB index on `file_id` to keep lookups fast. After creating the collection, run the following once (via Atlas UI, Compass, or `mongosh`): + +```javascript +db.getCollection("").createIndex({ file_id: 1 }) +``` + +Replace `` with the same collection used by the RAG API. This ensures lookups remain fast even as the number of embedded documents grows. 
+ ### Proxy Configuration @@ -169,6 +245,81 @@ Notes: ### Dev notes: +#### Running Tests + +##### Prerequisites + +Install test dependencies: + +```bash +pip install -r test_requirements.txt +``` + +##### Running All Tests + +```bash +# Run all tests +pytest + +# Run with verbose output +pytest -v + +# Run with coverage (if pytest-cov is installed) +pytest --cov=app +``` + +##### Running Specific Test Files + +```bash +# Run batch processing unit tests +pytest tests/test_batch_processing.py -v + +# Run batch processing integration tests (memory optimization tests) +pytest tests/test_batch_processing_integration.py -v + +# Run main API tests +pytest tests/test_main.py -v +``` + +##### Running Tests by Category + +```bash +# Run only integration tests (marked with @pytest.mark.integration) +pytest -m integration -v + +# Skip integration tests +pytest -m "not integration" -v + +# Run only async tests +pytest -k "async" -v +``` + +##### Test Categories + +| Test File | Description | +|-----------|-------------| +| `test_batch_processing.py` | Unit tests for batch processing functions | +| `test_batch_processing_integration.py` | Memory optimization and integration tests | +| `test_main.py` | API endpoint tests | +| `test_config.py` | Configuration tests | +| `test_middleware.py` | Middleware tests | +| `test_models.py` | Model tests | + +##### Memory Optimization Tests + +The `test_batch_processing_integration.py` file includes tests that verify the memory optimization behavior: + +- **`test_memory_bounded_by_batch_size`**: Verifies that the number of documents in memory at any time is bounded by `EMBEDDING_BATCH_SIZE` +- **`test_memory_tracking_with_tracemalloc`**: Uses Python's `tracemalloc` to monitor memory usage during batch processing +- **`test_sync_memory_bounded_by_batch_size`**: Same verification for the synchronous code path + +Run memory tests specifically: + +```bash +pytest tests/test_batch_processing_integration.py::TestMemoryOptimization -v +pytest 
tests/test_batch_processing_integration.py::TestSyncBatchedMemory -v +``` + #### Installing pre-commit formatter Run the following commands to install pre-commit formatter, which uses [black](https://github.com/psf/black) code formatter: diff --git a/app/config.py b/app/config.py index 27c140c2..fd7a6403 100644 --- a/app/config.py +++ b/app/config.py @@ -71,6 +71,23 @@ def get_env_variable( CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500")) CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100")) +# Batch processing configuration for memory-constrained environments. +# When EMBEDDING_BATCH_SIZE > 0, documents are processed in batches to reduce +# peak memory usage. This is useful for Kubernetes pods with memory limits. +# +# Trade-offs: +# - Smaller batch size = lower memory, more DB round trips +# - Larger batch size = higher memory, fewer DB round trips +# - 0 = disable batching, process all at once +# +# Default of 500 is conservative and works well for most embedding providers. +# Increase to 750 for higher throughput at the cost of higher peak memory. +EMBEDDING_BATCH_SIZE = int(get_env_variable("EMBEDDING_BATCH_SIZE", "500")) + +# Maximum number of batches to buffer in memory during async processing. +# Higher values allow more parallelism but use more memory. 
+EMBEDDING_MAX_QUEUE_SIZE = int(get_env_variable("EMBEDDING_MAX_QUEUE_SIZE", "3")) + env_value = get_env_variable("PDF_EXTRACT_IMAGES", "False").lower() PDF_EXTRACT_IMAGES = True if env_value == "true" else False @@ -241,12 +258,18 @@ def init_embeddings(provider, model): return GoogleGenerativeAIEmbeddings( model=model, - google_api_key=RAG_GOOGLE_API_KEY, + google_api_key=RAG_GOOGLE_API_KEY or None, ) elif provider == EmbeddingsProvider.GOOGLE_VERTEXAI: - from langchain_google_vertexai import VertexAIEmbeddings + from langchain_google_genai import GoogleGenerativeAIEmbeddings - return VertexAIEmbeddings(model=model) + return GoogleGenerativeAIEmbeddings( + model=model, + google_api_key=RAG_GOOGLE_API_KEY or None, + vertexai=True, + project=get_env_variable("GOOGLE_CLOUD_PROJECT", None), + location=get_env_variable("GOOGLE_CLOUD_LOCATION", "us-central1"), + ) elif provider == EmbeddingsProvider.BEDROCK: from langchain_aws import BedrockEmbeddings @@ -290,7 +313,7 @@ def init_embeddings(provider, model): "EMBEDDINGS_MODEL", "http://huggingfacetei:3000" ) elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.GOOGLE_VERTEXAI: - EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-004") + EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "gemini-embedding-001") elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA: EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text") elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.GOOGLE_GENAI: diff --git a/app/routes/document_routes.py b/app/routes/document_routes.py index 5be6dbb3..c5ebd3f5 100644 --- a/app/routes/document_routes.py +++ b/app/routes/document_routes.py @@ -1,11 +1,14 @@ # app/routes/document_routes.py import os +import uuid +from pathlib import Path import hashlib import traceback import aiofiles import aiofiles.os from shutil import copyfileobj -from typing import List, Iterable +from typing import List, Iterable, Optional, Union, TYPE_CHECKING +from concurrent.futures import 
ThreadPoolExecutor from fastapi import ( APIRouter, Request, @@ -18,11 +21,24 @@ status, ) from langchain_core.documents import Document -from langchain_core.runnables import run_in_executor from langchain_text_splitters import RecursiveCharacterTextSplitter from functools import lru_cache - -from app.config import logger, vector_store, RAG_UPLOAD_DIR, CHUNK_SIZE, CHUNK_OVERLAP +import asyncio + +if TYPE_CHECKING: + from app.services.vector_store.async_pg_vector import AsyncPgVector + from app.services.vector_store.atlas_mongo_vector import AtlasMongoVector + from langchain_community.vectorstores.pgvector import PGVector as PgVector + +from app.config import ( + logger, + vector_store, + RAG_UPLOAD_DIR, + CHUNK_SIZE, + CHUNK_OVERLAP, + EMBEDDING_BATCH_SIZE, + EMBEDDING_MAX_QUEUE_SIZE, +) from app.constants import ERROR_MESSAGES from app.models import ( StoreDocument, @@ -42,6 +58,13 @@ router = APIRouter() +def calculate_num_batches(total: int, batch_size: int) -> int: + """Calculate the number of batches needed to process total items.""" + if batch_size <= 0: + return 1 + return (total + batch_size - 1) // batch_size + + def get_user_id(request: Request, entity_id: str = None) -> str: """Extract user ID from request or entity_id.""" if not hasattr(request.state, "user"): @@ -88,17 +111,47 @@ def save_upload_file_sync(file: UploadFile, temp_file_path: str) -> None: ) +def validate_file_path(base_dir: str, file_path: str) -> Optional[str]: + """Validate that file_path resolves within base_dir. 
Returns resolved absolute path or None.""" + if not file_path or not file_path.strip(): + return None + try: + allowed = Path(base_dir).resolve() + requested = Path(os.path.join(base_dir, file_path)).resolve() + requested.relative_to(allowed) + return str(requested) + except (ValueError, RuntimeError, TypeError, OSError): + return None + + +def _make_unique_temp_path(user_id: str, filename: str) -> Optional[str]: + """Build a unique temp file path under RAG_UPLOAD_DIR/{user_id}/ to prevent + concurrent upload collisions. Returns a validated absolute path, or None if + the raw filename would escape RAG_UPLOAD_DIR (path traversal rejection).""" + # Validate the raw filename to reject traversal attempts + if validate_file_path(RAG_UPLOAD_DIR, os.path.join(user_id, filename)) is None: + return None + # unique_name is stem + "_" + [0-9a-f]{32} + suffix — no path separators, + # so it cannot escape the directory validated above. + p = Path(filename) + unique_name = f"{p.stem}_{uuid.uuid4().hex}{p.suffix}" + return str(Path(RAG_UPLOAD_DIR, user_id, unique_name).resolve()) + + async def load_file_content( filename: str, content_type: str, file_path: str, executor ) -> tuple: """Load file content using appropriate loader.""" - loader, known_type, file_ext = get_loader(filename, content_type, file_path) - data = await run_in_executor(executor, loader.load) - - # Clean up temporary UTF-8 file if it was created for encoding conversion - cleanup_temp_encoding_file(loader) - - return data, known_type, file_ext + loader = None + try: + loader, known_type, file_ext = get_loader(filename, content_type, file_path) + loop = asyncio.get_running_loop() + data = await loop.run_in_executor(executor, lambda: list(loader.lazy_load())) + return data, known_type, file_ext + finally: + # Clean up temporary UTF-8 file if it was created for encoding conversion + if loader is not None: + cleanup_temp_encoding_file(loader) def extract_text_from_documents(documents: List[Document], file_ext: str) 
-> str: @@ -279,12 +332,12 @@ async def query_embeddings_by_file_id( documents = await vector_store.asimilarity_search_with_score_by_vector( embedding, k=body.k, - filter={"file_id": body.file_id}, + filter={"file_id": {"$eq": body.file_id}}, executor=request.app.state.thread_pool, ) else: documents = vector_store.similarity_search_with_score_by_vector( - embedding, k=body.k, filter={"file_id": body.file_id} + embedding, k=body.k, filter={"file_id": {"$eq": body.file_id}} ) if not documents: @@ -336,23 +389,265 @@ async def query_embeddings_by_file_id( raise HTTPException(status_code=500, detail=str(e)) -def generate_digest(page_content: str): +async def _process_documents_async_pipeline( + documents: List[Document], + file_id: str, + vector_store: "AsyncPgVector", + executor: "ThreadPoolExecutor", +) -> List[str]: + """ + Process documents using async producer-consumer pattern for batched embedding and insertion. + + Args: + documents: List of Document objects to process + file_id: Unique identifier for the file being processed + vector_store: AsyncPgVector instance for document storage + executor: ThreadPoolExecutor for concurrent operations + + Returns: + List of document IDs that were successfully inserted + """ + total_chunks = len(documents) + if total_chunks == 0: + return [] + + # Create queues for producer-consumer pattern + # embedding_queue is bounded to limit document data held in memory. + # results_queue is unbounded — it holds only small UUID lists, and the + # drain loop runs after gather(), so bounding it would deadlock when + # num_batches > maxsize. 
+ embedding_queue = asyncio.Queue(maxsize=EMBEDDING_MAX_QUEUE_SIZE) + results_queue = asyncio.Queue() + all_ids = [] + + num_batches = calculate_num_batches(total_chunks, EMBEDDING_BATCH_SIZE) + + logger.info( + "Starting async pipeline for file %s: %d chunks with %d batch size", + file_id, + total_chunks, + EMBEDDING_BATCH_SIZE, + ) + + async def batch_producer(): + """Produce document batches and put them in the queue.""" + try: + for batch_idx in range(num_batches): + start_idx = batch_idx * EMBEDDING_BATCH_SIZE + end_idx = min(start_idx + EMBEDDING_BATCH_SIZE, total_chunks) + batch_documents = documents[start_idx:end_idx] + batch_ids = [file_id] * len(batch_documents) + + logger.info( + "Generating embeddings for batch %d/%d: chunks %d-%d", + batch_idx + 1, + num_batches, + start_idx, + end_idx - 1, + ) + + # Put batch in queue for processing + await embedding_queue.put( + (batch_documents, batch_ids, batch_idx + 1, num_batches) + ) + except Exception as e: + logger.error("Error in batch producer: %s", e) + raise + finally: + # Always signal end of production + await embedding_queue.put(None) + + async def embedding_consumer(): + """Consume batches from queue, embed and insert into database.""" + try: + while True: + item = await embedding_queue.get() + if item is None: # End signal + embedding_queue.task_done() + break + + batch_documents, batch_ids, batch_num, total_batches = item + + logger.info( + "Inserting batch %d/%d into database (%d chunks)", + batch_num, + total_batches, + len(batch_documents), + ) + + try: + # Insert batch into database + batch_result_ids = await vector_store.aadd_documents( + batch_documents, ids=batch_ids, executor=executor + ) + await results_queue.put(batch_result_ids) + except Exception as e: + logger.error( + "Error processing batch %d/%d: %s", batch_num, total_batches, e + ) + await results_queue.put(e) # Put exception object + finally: + embedding_queue.task_done() + + except Exception as e: + logger.error("Fatal error in 
embedding consumer: %s", e) + await results_queue.put(e) + raise + + producer_task = None + consumer_task = None + try: - hash_obj = hashlib.md5(page_content.encode("utf-8")) - except UnicodeEncodeError: - hash_obj = hashlib.md5( - page_content.encode("utf-8", "ignore").decode("utf-8").encode("utf-8") + # Start producer and consumer concurrently + producer_task = asyncio.create_task(batch_producer()) + consumer_task = asyncio.create_task(embedding_consumer()) + + # Wait for both to complete + await asyncio.gather(producer_task, consumer_task, return_exceptions=False) + + # Collect results from all batches + for _ in range(num_batches): + result = await results_queue.get() + if isinstance(result, Exception): + raise result + all_ids.extend(result) + + logger.info( + "Async pipeline completed for file %s: %d embeddings created", + file_id, + len(all_ids), ) - return hash_obj.hexdigest() + return all_ids -async def store_data_in_vector_db( + except Exception as e: + logger.error("Pipeline failed for file %s: %s", file_id, e) + if consumer_task is not None or producer_task is not None: + # if one of the tasks is still running, cancel it + if consumer_task is not None and not consumer_task.done(): + consumer_task.cancel() + if producer_task is not None and not producer_task.done(): + producer_task.cancel() + + # Await cancelled tasks to ensure proper cleanup + if consumer_task is None: + await asyncio.gather(producer_task, return_exceptions=True) + elif producer_task is None: + await asyncio.gather(consumer_task, return_exceptions=True) + else: + await asyncio.gather( + consumer_task, producer_task, return_exceptions=True + ) + + # Attempt rollback only if we inserted something + if all_ids: + try: + logger.warning("Performing rollback of file %s", file_id) + await vector_store.delete(ids=[file_id], executor=executor) + logger.info("Rollback completed for file %s", file_id) + except Exception as cleanup_error: + logger.error("Rollback failed for file %s: %s", file_id, 
cleanup_error) + + # Re-raise the original error + raise + + +async def _process_documents_batched_sync( + documents: List[Document], + file_id: str, + vector_store: Union["PgVector", "AtlasMongoVector"], + executor: "ThreadPoolExecutor", +) -> List[str]: + """ + Process documents in batches using synchronous vector store operations. + + Args: + documents: List of Document objects to process + file_id: Unique identifier for the file being processed + vector_store: Synchronous vector store instance (ExtendedPgVector or AtlasMongoVector) + executor: ThreadPoolExecutor for running sync operations + + Returns: + List of document IDs that were successfully inserted + """ + total_chunks = len(documents) + if total_chunks == 0: + return [] + + all_ids = [] + num_batches = calculate_num_batches(total_chunks, EMBEDDING_BATCH_SIZE) + + logger.info( + "Processing file %s with sync batching: %d batches of %d chunks each", + file_id, + num_batches, + EMBEDDING_BATCH_SIZE, + ) + + loop = asyncio.get_running_loop() + + for batch_idx in range(num_batches): + start_idx = batch_idx * EMBEDDING_BATCH_SIZE + end_idx = min(start_idx + EMBEDDING_BATCH_SIZE, total_chunks) + batch_documents = documents[start_idx:end_idx] + batch_ids = [file_id] * len(batch_documents) + + logger.info( + "Processing batch %d/%d: chunks %d-%d (%d chunks)", + batch_idx + 1, + num_batches, + start_idx, + end_idx - 1, + len(batch_documents), + ) + + try: + # Wrap sync call in executor to avoid blocking the event loop + batch_result_ids = await loop.run_in_executor( + executor, + lambda docs=batch_documents, ids=batch_ids: vector_store.add_documents( + docs, ids=ids + ), + ) + all_ids.extend(batch_result_ids) + + except Exception as batch_error: + logger.error("Batch %d failed: %s", batch_idx + 1, batch_error) + + # Rollback entire file from vector store + if ( + all_ids + ): # any batch succeeded (i.e., any chunks for this file were inserted) + logger.warning("Rolling back file %s due to batch failure", 
file_id) + try: + await loop.run_in_executor( + executor, lambda: vector_store.delete(ids=[file_id]) + ) + logger.info("Rollback completed for file %s", file_id) + except Exception as rollback_error: + logger.error( + "Rollback failed for file %s: %s", file_id, rollback_error + ) + + raise batch_error + + return all_ids + + +def generate_digest(page_content: str) -> str: + return hashlib.md5(page_content.encode("utf-8", "ignore")).hexdigest() + + +def _prepare_documents_sync( data: Iterable[Document], file_id: str, - user_id: str = "", - clean_content: bool = False, - executor=None, -) -> bool: + user_id: str, + clean_content: bool, +) -> List[Document]: + """ + Synchronous document preparation - runs in executor to avoid blocking event loop. + Handles text splitting, cleaning, and metadata preparation. + """ text_splitter = RecursiveCharacterTextSplitter( chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP ) @@ -364,7 +659,7 @@ async def store_data_in_vector_db( doc.page_content = clean_text(doc.page_content) # Preparing documents with page content and metadata for insertion. 
- docs = [ + return [ Document( page_content=doc.page_content, metadata={ @@ -377,13 +672,48 @@ async def store_data_in_vector_db( for doc in documents ] + +async def store_data_in_vector_db( + data: Iterable[Document], + file_id: str, + user_id: str = "", + clean_content: bool = False, + executor=None, +) -> bool: + # Run document preparation in executor to avoid blocking the event loop + loop = asyncio.get_running_loop() + docs = await loop.run_in_executor( + executor, + _prepare_documents_sync, + data, + file_id, + user_id, + clean_content, + ) + try: - if isinstance(vector_store, AsyncPgVector): - ids = await vector_store.aadd_documents( - docs, ids=[file_id] * len(documents), executor=executor - ) + if EMBEDDING_BATCH_SIZE <= 0: + # synchronously embed the file and insert into vector store in one go + if isinstance(vector_store, AsyncPgVector): + ids = await vector_store.aadd_documents( + docs, ids=[file_id] * len(docs), executor=executor + ) + else: + ids = vector_store.add_documents(docs, ids=[file_id] * len(docs)) else: - ids = vector_store.add_documents(docs, ids=[file_id] * len(documents)) + # asynchronously embed the file and insert into vector store as it is embedding + # to lessen memory impact and speed up slightly as the majority of the document + # is inserted into db by the time it is fully embedded + + if isinstance(vector_store, AsyncPgVector): + ids = await _process_documents_async_pipeline( + docs, file_id, vector_store, executor + ) + else: + # Fallback to batched processing for sync vector stores + ids = await _process_documents_batched_sync( + docs, file_id, vector_store, executor + ) return {"message": "Documents added successfully", "ids": ids} @@ -402,8 +732,11 @@ async def store_data_in_vector_db( async def embed_local_file( document: StoreDocument, request: Request, entity_id: str = None ): - # Check if the file exists - if not os.path.exists(document.filepath): + file_path = validate_file_path(RAG_UPLOAD_DIR, document.filepath) + + # 
Check if the file exists and if it is within the allowed upload directory + if file_path is None or not os.path.exists(file_path): + logger.warning("Path validation failed for local embed: %s", document.filepath) raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.FILE_NOT_FOUND, @@ -414,14 +747,15 @@ async def embed_local_file( else: user_id = entity_id if entity_id else request.state.user.get("id") + loader = None try: loader, known_type, file_ext = get_loader( - document.filename, document.file_content_type, document.filepath + document.filename, document.file_content_type, file_path + ) + loop = asyncio.get_running_loop() + data = await loop.run_in_executor( + request.app.state.thread_pool, lambda: list(loader.lazy_load()) ) - data = await run_in_executor(request.app.state.thread_pool, loader.load) - - # Clean up temporary UTF-8 file if it was created for encoding conversion - cleanup_temp_encoding_file(loader) result = await store_data_in_vector_db( data, @@ -462,6 +796,10 @@ async def embed_local_file( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) + finally: + # Clean up temporary UTF-8 file if it was created for encoding conversion + if loader is not None: + cleanup_temp_encoding_file(loader) @router.post("/embed") @@ -476,17 +814,22 @@ async def embed_file( known_type = None user_id = get_user_id(request, entity_id) - temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id) - os.makedirs(temp_base_path, exist_ok=True) - temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename) + validated_file_path = _make_unique_temp_path(user_id, file.filename) - await save_upload_file_async(file, temp_file_path) + if validated_file_path is None: + logger.warning("Path validation failed for embed: %s", file.filename) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Invalid request"), + ) try: + os.makedirs(os.path.dirname(validated_file_path), 
exist_ok=True) + await save_upload_file_async(file, validated_file_path) data, known_type, file_ext = await load_file_content( file.filename, file.content_type, - temp_file_path, + validated_file_path, request.app.state.thread_pool, ) @@ -537,7 +880,7 @@ async def embed_file( detail=f"Error during file processing: {str(e)}", ) finally: - await cleanup_temp_file_async(temp_file_path) + await cleanup_temp_file_async(validated_file_path) return { "status": response_status, @@ -604,15 +947,25 @@ async def embed_file_upload( entity_id: str = Form(None), ): user_id = get_user_id(request, entity_id) - temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename) - save_upload_file_sync(uploaded_file, temp_file_path) + validated_temp_file_path = _make_unique_temp_path(user_id, uploaded_file.filename) + + if validated_temp_file_path is None: + logger.warning( + "Path validation failed for embed-upload: %s", uploaded_file.filename + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Invalid request"), + ) try: + os.makedirs(os.path.dirname(validated_temp_file_path), exist_ok=True) + await save_upload_file_async(uploaded_file, validated_temp_file_path) data, known_type, file_ext = await load_file_content( uploaded_file.filename, uploaded_file.content_type, - temp_file_path, + validated_temp_file_path, request.app.state.thread_pool, ) @@ -648,7 +1001,7 @@ async def embed_file_upload( detail=f"Error during file processing: {str(e)}", ) finally: - os.remove(temp_file_path) + await cleanup_temp_file_async(validated_temp_file_path) return { "status": True, @@ -715,17 +1068,22 @@ async def extract_text_from_file( Returns the raw text content for text parsing purposes. 
""" user_id = get_user_id(request, entity_id) - temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id) - os.makedirs(temp_base_path, exist_ok=True) - temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename) + validated_temp_file_path = _make_unique_temp_path(user_id, file.filename) - await save_upload_file_async(file, temp_file_path) + if validated_temp_file_path is None: + logger.warning("Path validation failed for text extraction: %s", file.filename) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Invalid request"), + ) try: + os.makedirs(os.path.dirname(validated_temp_file_path), exist_ok=True) + await save_upload_file_async(file, validated_temp_file_path) data, known_type, file_ext = await load_file_content( file.filename, file.content_type, - temp_file_path, + validated_temp_file_path, request.app.state.thread_pool, ) @@ -764,4 +1122,4 @@ async def extract_text_from_file( detail=f"Error during text extraction: {str(e)}", ) finally: - await cleanup_temp_file_async(temp_file_path) + await cleanup_temp_file_async(validated_temp_file_path) diff --git a/app/services/database.py b/app/services/database.py index 712e43c1..ae77edd3 100644 --- a/app/services/database.py +++ b/app/services/database.py @@ -20,6 +20,15 @@ async def close_pool(cls): async def ensure_vector_indexes(): + """Ensure required indexes on langchain_pg_embedding and migrate cmetadata to JSONB. + + Runs at startup. Idempotent — safe to call repeatedly. + Operations: + 1. B-tree index on custom_id. + 2. Expression index on (cmetadata->>'file_id'). + 3. DDL migration: JSON -> JSONB for cmetadata (skipped if already JSONB). + 4. GIN index (jsonb_path_ops) on cmetadata for containment queries. 
+ """ table_name = "langchain_pg_embedding" column_name = "custom_id" # You might want to standardize the index naming convention @@ -33,13 +42,49 @@ async def ensure_vector_indexes(): """ ) + # Expression index for (cmetadata->>'file_id') text queries. + # NOTE: After the JSONB migration, LangChain generates @> containment + # queries served by ix_cmetadata_gin instead. Consider dropping this + # index in a follow-up once JSONB filtering is confirmed stable. await conn.execute( f""" - CREATE INDEX IF NOT EXISTS idx_{table_name}_file_id + CREATE INDEX IF NOT EXISTS idx_{table_name}_file_id ON {table_name} ((cmetadata->>'file_id')); """ ) + # Migrate cmetadata from JSON to JSONB (idempotent — skipped if already JSONB). + # Rollback: ALTER TABLE langchain_pg_embedding ALTER COLUMN cmetadata TYPE JSON USING cmetadata::json; + # NOTE: table name is hardcoded below (not interpolated) to avoid SQL injection. + await conn.execute( + """ + DO $$ + BEGIN + IF EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'langchain_pg_embedding' + AND table_schema = current_schema() + AND column_name = 'cmetadata' + AND data_type = 'json' + ) THEN + SET LOCAL lock_timeout = '10s'; + ALTER TABLE langchain_pg_embedding + ALTER COLUMN cmetadata TYPE JSONB USING cmetadata::jsonb; + END IF; + END + $$; + """ + ) + + # GIN index on cmetadata for efficient JSONB filtering + await conn.execute( + """ + CREATE INDEX IF NOT EXISTS ix_cmetadata_gin + ON langchain_pg_embedding + USING gin (cmetadata jsonb_path_ops); + """ + ) + logger.info("Vector database indexes ensured") diff --git a/app/services/mongo_client.py b/app/services/mongo_client.py index 8e5db46b..b1882be6 100644 --- a/app/services/mongo_client.py +++ b/app/services/mongo_client.py @@ -1,4 +1,5 @@ # app/services/mongo_client.py +import asyncio import logging from pymongo import MongoClient from pymongo.errors import PyMongoError @@ -6,11 +7,19 @@ logger = logging.getLogger(__name__) + async def 
mongo_health_check() -> bool: + client = None try: - client = MongoClient(ATLAS_MONGO_DB_URI) - client.admin.command("ping") + client = await asyncio.to_thread(MongoClient, ATLAS_MONGO_DB_URI) + await asyncio.to_thread(client.admin.command, "ping") return True except PyMongoError as e: - logger.error(f"MongoDB health check failed: {e}") - return False \ No newline at end of file + logger.error("MongoDB health check failed: %s", e) + return False + finally: + if client is not None: + try: + await asyncio.to_thread(client.close) + except Exception as e: + logger.debug("Failed to close health check client: %s", e) diff --git a/app/services/vector_store/async_pg_vector.py b/app/services/vector_store/async_pg_vector.py index dd1371d8..f4d8381a 100644 --- a/app/services/vector_store/async_pg_vector.py +++ b/app/services/vector_store/async_pg_vector.py @@ -1,75 +1,100 @@ -from typing import Optional, List, Tuple, Dict, Any +from typing import Callable, Optional, List, Tuple, Dict, Any, TypeVar import asyncio +from concurrent.futures import Executor from langchain_core.documents import Document -from langchain_core.runnables.config import run_in_executor from .extended_pg_vector import ExtendedPgVector +T = TypeVar("T") + + class AsyncPgVector(ExtendedPgVector): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._thread_pool = None - + def _get_thread_pool(self): if self._thread_pool is None: try: - # Try to get the thread pool from FastAPI app state - import contextvars - from fastapi import Request - # This is a fallback - in practice, we'll pass the executor explicitly loop = asyncio.get_running_loop() - self._thread_pool = getattr(loop, '_default_executor', None) - except: + self._thread_pool = getattr(loop, "_default_executor", None) + except Exception: pass return self._thread_pool - + + @staticmethod + async def _run_in_executor( + executor: Executor | None, + func: Callable[..., T], + *args: Any, + **kwargs: Any, + ) -> T: + """Run a sync 
callable in a thread pool executor. + + Wraps the call to convert StopIteration into RuntimeError. + StopIteration cannot be set on an asyncio.Future — it raises + TypeError and leaves the Future pending forever. + """ + + def wrapper() -> T: + try: + return func(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError from exc + + loop = asyncio.get_running_loop() + return await loop.run_in_executor(executor, wrapper) + async def get_all_ids(self, executor=None) -> list[str]: executor = executor or self._get_thread_pool() - return await run_in_executor(executor, super().get_all_ids) - + return await self._run_in_executor(executor, super().get_all_ids) + async def get_filtered_ids(self, ids: list[str], executor=None) -> list[str]: executor = executor or self._get_thread_pool() - return await run_in_executor(executor, super().get_filtered_ids, ids) + return await self._run_in_executor(executor, super().get_filtered_ids, ids) - async def get_documents_by_ids(self, ids: list[str], executor=None) -> list[Document]: + async def get_documents_by_ids( + self, ids: list[str], executor=None + ) -> list[Document]: executor = executor or self._get_thread_pool() - return await run_in_executor(executor, super().get_documents_by_ids, ids) + return await self._run_in_executor(executor, super().get_documents_by_ids, ids) async def delete( - self, ids: Optional[list[str]] = None, collection_only: bool = False, executor=None + self, + ids: Optional[list[str]] = None, + collection_only: bool = False, + executor=None, ) -> None: executor = executor or self._get_thread_pool() - await run_in_executor(executor, self._delete_multiple, ids, collection_only) - + await self._run_in_executor( + executor, self._delete_multiple, ids, collection_only + ) + async def asimilarity_search_with_score_by_vector( - self, - embedding: List[float], - k: int = 4, + self, + embedding: List[float], + k: int = 4, filter: Optional[Dict[str, Any]] = None, - executor=None + executor=None, ) -> 
List[Tuple[Document, float]]: """Async version of similarity_search_with_score_by_vector""" executor = executor or self._get_thread_pool() - return await run_in_executor( - executor, - super().similarity_search_with_score_by_vector, - embedding, - k, - filter + return await self._run_in_executor( + executor, + super().similarity_search_with_score_by_vector, + embedding, + k, + filter, ) - + async def aadd_documents( - self, - documents: List[Document], + self, + documents: List[Document], ids: Optional[List[str]] = None, executor=None, - **kwargs + **kwargs, ) -> List[str]: """Async version of add_documents""" executor = executor or self._get_thread_pool() - return await run_in_executor( - executor, - super().add_documents, - documents, - ids=ids, - **kwargs - ) \ No newline at end of file + return await self._run_in_executor( + executor, super().add_documents, documents, ids=ids, **kwargs + ) diff --git a/app/services/vector_store/atlas_mongo_vector.py b/app/services/vector_store/atlas_mongo_vector.py index 3fa97074..dc94a922 100644 --- a/app/services/vector_store/atlas_mongo_vector.py +++ b/app/services/vector_store/atlas_mongo_vector.py @@ -1,20 +1,33 @@ import copy +import hashlib from typing import Any, List, Optional, Tuple from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_mongodb import MongoDBAtlasVectorSearch + class AtlasMongoVector(MongoDBAtlasVectorSearch): @property def embedding_function(self) -> Embeddings: return self.embeddings - def add_documents(self, docs: list[Document], ids: list[str]): - # {file_id}_{idx} - new_ids = [id for id in range(len(ids))] - file_id = docs[0].metadata['file_id'] - f_ids = [f'{file_id}_{id}' for id in new_ids] - return super().add_documents(docs, f_ids) + def add_documents( + self, + documents: List[Document], + ids: Optional[List[str]] = None, + **kwargs, + ) -> List[str]: + """Caller-supplied ``ids`` are intentionally ignored; IDs are derived from + each 
document's content digest to ensure cross-batch uniqueness within a file. + """ + if not documents: + return [] + file_id = documents[0].metadata["file_id"] + f_ids = [ + f"{file_id}_{doc.metadata.get('digest') or hashlib.md5(doc.page_content.encode()).hexdigest()}" + for doc in documents + ] + return super().add_documents(documents, f_ids) def similarity_search_with_score_by_vector( self, @@ -44,7 +57,7 @@ def similarity_search_with_score_by_vector( def get_all_ids(self) -> list[str]: # Return unique file_id fields in self._collection return self._collection.distinct("file_id") - + def get_filtered_ids(self, ids: list[str]) -> list[str]: # Return unique file_id fields filtered by the provided ids return self._collection.distinct("file_id", {"file_id": {"$in": ids}}) @@ -68,4 +81,4 @@ def get_documents_by_ids(self, ids: list[str]) -> list[Document]: def delete(self, ids: Optional[list[str]] = None) -> None: # Delete documents by file_id if ids is not None: - self._collection.delete_many({"file_id": {"$in": ids}}) \ No newline at end of file + self._collection.delete_many({"file_id": {"$in": ids}}) diff --git a/app/services/vector_store/factory.py b/app/services/vector_store/factory.py index 17ca3e55..854cc78a 100644 --- a/app/services/vector_store/factory.py +++ b/app/services/vector_store/factory.py @@ -1,4 +1,6 @@ -from typing import Optional +import logging +from typing import Any, Optional + from pymongo import MongoClient from langchain_core.embeddings import Embeddings @@ -6,31 +8,79 @@ from .atlas_mongo_vector import AtlasMongoVector from .extended_pg_vector import ExtendedPgVector +logger = logging.getLogger(__name__) + +# Holds the MongoClient so it can be closed on shutdown. 
+_mongo_client: Optional[MongoClient] = None + def get_vector_store( connection_string: str, embeddings: Embeddings, collection_name: str, mode: str = "sync", - search_index: Optional[str] = None + search_index: Optional[str] = None, ): + """Create a vector store instance for the given mode. + + Note: For 'atlas-mongo' mode, the MongoClient is stored at module level + so it can be closed on shutdown via close_vector_store_connections(). + """ + global _mongo_client + if mode == "sync": return ExtendedPgVector( connection_string=connection_string, embedding_function=embeddings, collection_name=collection_name, + use_jsonb=True, ) elif mode == "async": return AsyncPgVector( connection_string=connection_string, embedding_function=embeddings, collection_name=collection_name, + use_jsonb=True, ) elif mode == "atlas-mongo": - mongo_db = MongoClient(connection_string).get_database() + if _mongo_client is not None: + _mongo_client.close() + _mongo_client = MongoClient(connection_string) + mongo_db = _mongo_client.get_database() mong_collection = mongo_db[collection_name] return AtlasMongoVector( collection=mong_collection, embedding=embeddings, index_name=search_index ) else: - raise ValueError("Invalid mode specified. Choose 'sync', 'async', or 'atlas-mongo'.") \ No newline at end of file + raise ValueError( + "Invalid mode specified. Choose 'sync', 'async', or 'atlas-mongo'." + ) + + +def close_vector_store_connections(vector_store: Any) -> None: + """Close connections held by the vector store and its backing clients. + + Closes the module-level MongoClient (if atlas-mongo mode was used) and + disposes the SQLAlchemy engine on the vector store (if pgvector mode). + Safe to call multiple times. 
+ """ + global _mongo_client + + # Close MongoDB client if one was created + if _mongo_client is not None: + try: + _mongo_client.close() + logger.info("MongoDB client closed") + except Exception as e: + logger.warning("Failed to close MongoDB client: %s", e) + finally: + _mongo_client = None + + # Dispose SQLAlchemy engine if the vector store has one + engine = getattr(vector_store, "_bind", None) + if engine is not None and hasattr(engine, "dispose"): + try: + engine.dispose() + logger.info("SQLAlchemy engine disposed") + except Exception as e: + logger.warning("Failed to dispose SQLAlchemy engine: %s", e) diff --git a/app/utils/document_loader.py b/app/utils/document_loader.py index 5c2f273d..900b1561 100644 --- a/app/utils/document_loader.py +++ b/app/utils/document_loader.py @@ -4,7 +4,7 @@ import codecs import tempfile -from typing import List, Optional +from typing import Iterator, List, Optional import chardet from langchain_core.documents import Document @@ -82,29 +82,27 @@ def get_loader(filename: str, file_content_type: str, filepath: str): encoding = detect_file_encoding(filepath) if encoding != "utf-8": - # For non-UTF-8 encodings, we need to convert the file first - # Create a temporary UTF-8 file + # For non-UTF-8 encodings, convert to UTF-8 using streaming + # to avoid holding the entire file in memory as a single string temp_file = None try: with tempfile.NamedTemporaryFile( mode="w", encoding="utf-8", suffix=".csv", delete=False ) as temp_file: - # Read the original file with detected encoding with open( filepath, "r", encoding=encoding, errors="replace" ) as original_file: - content = original_file.read() - temp_file.write(content) + while True: + chunk = original_file.read(64 * 1024) + if not chunk: + break + temp_file.write(chunk) temp_filepath = temp_file.name - # Use the temporary UTF-8 file with CSVLoader loader = CSVLoader(temp_filepath) - - # Store the temp file path for cleanup loader._temp_filepath = temp_filepath except Exception as e: 
- # If temp file was created but there was an error, clean it up if temp_file and os.path.exists(temp_file.name): os.unlink(temp_file.name) raise e @@ -235,19 +233,32 @@ def __init__(self, filepath: str, extract_images: bool = False): self.extract_images = extract_images self._temp_filepath = None # For compatibility with cleanup function - def load(self) -> List[Document]: - """Load PDF documents with automatic fallback on image extraction errors.""" + def lazy_load(self) -> Iterator[Document]: + """Lazy load PDF documents with automatic fallback on image extraction errors.""" loader = PyPDFLoader(self.filepath, extract_images=self.extract_images) + if not self.extract_images: + # No image extraction: no fallback needed, stream directly + yield from loader.lazy_load() + return + + # extract_images=True: must collect eagerly so that a mid-stream + # KeyError doesn't leave already-yielded pages duplicated by the + # fallback (yield from + try/except would deliver partial + full). try: - return loader.load() + pages = list(loader.lazy_load()) except KeyError as e: - if "/Filter" in str(e) and self.extract_images: + if "/Filter" in str(e): logger.warning( f"PDF image extraction failed for {self.filepath}, falling back to text-only: {e}" ) fallback_loader = PyPDFLoader(self.filepath, extract_images=False) - return fallback_loader.load() + pages = list(fallback_loader.lazy_load()) else: # Re-raise if it's a different error raise + yield from pages + + def load(self) -> List[Document]: + """Load PDF documents with automatic fallback on image extraction errors.""" + return list(self.lazy_load()) diff --git a/app/utils/health.py b/app/utils/health.py index 95178809..78fd6f78 100644 --- a/app/utils/health.py +++ b/app/utils/health.py @@ -4,10 +4,10 @@ from app.services.mongo_client import mongo_health_check -def is_health_ok(): +async def is_health_ok(): if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: - return pg_health_check() + return await pg_health_check() if VECTOR_DB_TYPE 
== VectorDBType.ATLAS_MONGO: - return mongo_health_check() + return await mongo_health_check() else: - return True \ No newline at end of file + return True diff --git a/docker-compose.yaml b/docker-compose.yaml index 225299ef..45f59cc8 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -15,6 +15,7 @@ services: environment: - DB_HOST=db - DB_PORT=5432 + - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-500} ports: - "8000:8000" volumes: diff --git a/main.py b/main.py index 8af99f97..e300951a 100644 --- a/main.py +++ b/main.py @@ -20,10 +20,12 @@ VECTOR_DB_TYPE, LogMiddleware, logger, + vector_store, ) from app.middleware import security_middleware from app.routes import document_routes, pgvector_routes from app.services.database import PSQLDatabase, ensure_vector_indexes +from app.services.vector_store.factory import close_vector_store_connections @asynccontextmanager @@ -47,10 +49,25 @@ async def lifespan(app: FastAPI): yield # Cleanup logic + if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: + try: + logger.info("Closing asyncpg connection pool") + await PSQLDatabase.close_pool() + logger.info("asyncpg connection pool closed") + except Exception as e: + logger.warning("Failed to close asyncpg pool: %s", e) + + # Drain in-flight work before closing backing resources logger.info("Shutting down thread pool") app.state.thread_pool.shutdown(wait=True) logger.info("Thread pool shutdown complete") + # Close vector store connections (MongoDB client / SQLAlchemy engine) + try: + close_vector_store_connections(vector_store) + except Exception as e: + logger.warning("Failed to close vector store connections: %s", e) + app = FastAPI(lifespan=lifespan, debug=debug_mode) @@ -79,17 +96,10 @@ async def lifespan(app: FastAPI): @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): - body = await request.body() - logger.debug(f"Validation error occurred") - logger.debug(f"Raw request body: 
{body.decode()}") - logger.debug(f"Validation errors: {exc.errors()}") + logger.debug("Validation error: %s", exc.errors()) return JSONResponse( status_code=422, - content={ - "detail": exc.errors(), - "body": body.decode(), - "message": "Request validation failed", - }, + content={"detail": exc.errors(), "message": "Request validation failed"}, ) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..384296e7 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,23 @@ +[pytest] +# Test discovery +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Async mode for pytest-asyncio +asyncio_mode = auto + +# Markers +markers = + integration: marks tests as integration tests (may require more resources) + slow: marks tests as slow running + +# Default options +addopts = -v --tb=short + +# Filter warnings +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning + diff --git a/requirements.lite.txt b/requirements.lite.txt index 661be35a..b3454c06 100644 --- a/requirements.lite.txt +++ b/requirements.lite.txt @@ -1,36 +1,37 @@ -langchain==0.3.26 -langchain-community==0.3.27 -langchain-openai==0.3.27 -langchain-core==0.3.75 -langchain-google-genai==2.1.10 -langchain-google-vertexai==2.0.27 +langchain==1.2.10 +langchain-community==0.4.1 +langchain-openai==1.1.10 +langchain-core==1.2.16 +langchain-google-genai==4.2.0 sqlalchemy==2.0.41 python-dotenv==1.1.1 fastapi==0.115.12 psycopg2-binary==2.9.9 pgvector==0.2.5 uvicorn==0.28.0 -pypdf==6.0.0 -unstructured==0.16.11 +pypdf==6.9.1 +unstructured==0.18.32 markdown==3.8.2 networkx==3.2.1 pandas==2.2.1 openpyxl==3.1.5 docx2txt==0.9 pypandoc==1.15 -PyJWT==2.8.0 +PyJWT==2.12.1 asyncpg==0.29.0 -python-multipart==0.0.19 +python-multipart==0.0.22 aiofiles==24.1.0 rapidocr-onnxruntime==1.4.4 opencv-python-headless==4.9.0.80 -pymongo==4.6.3 -langchain-mongodb==0.2.0 -cryptography==45.0.5 +pymongo>=4.12.0,<5 +langchain-mongodb==0.11.0 +cryptography==46.0.5 
python-magic==0.4.27 python-pptx==1.0.2 xlrd==2.0.2 -langchain-aws==0.2.1 -boto3==1.34.144 +langchain-aws==1.3.1 +boto3>=1.42.42,<2 chardet==5.2.0 -langchain-ollama==0.3.3 \ No newline at end of file +langchain-ollama==1.0.1 +tenacity>=9.0.0 +msoffcrypto-tool>=6.0.0,<7 diff --git a/requirements.txt b/requirements.txt index 0f811572..18cfaf06 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,40 +1,41 @@ -langchain==0.3.26 -langchain-community==0.3.27 -langchain-openai==0.3.27 -langchain-core==0.3.75 -langchain-aws==0.2.1 -langchain-google-vertexai==2.0.27 -langchain_text_splitters==0.3.8 # 0.3.3 -boto3==1.34.144 +langchain==1.2.10 +langchain-community==0.4.1 +langchain-openai==1.1.10 +langchain-core==1.2.16 +langchain-aws==1.3.1 +langchain-text-splitters==1.1.1 +boto3>=1.42.42,<2 sqlalchemy==2.0.41 python-dotenv==1.1.1 fastapi==0.115.12 psycopg2-binary==2.9.9 pgvector==0.2.5 uvicorn==0.28.0 -pypdf==6.0.0 -unstructured==0.16.11 +pypdf==6.9.1 +unstructured==0.18.32 markdown==3.8.2 networkx==3.2.1 pandas==2.2.1 openpyxl==3.1.5 docx2txt==0.9 pypandoc==1.15 -PyJWT==2.8.0 +PyJWT==2.12.1 asyncpg==0.29.0 -python-multipart==0.0.19 +python-multipart==0.0.22 sentence_transformers==3.1.1 aiofiles==24.1.0 rapidocr-onnxruntime==1.4.4 opencv-python-headless==4.9.0.80 -pymongo==4.6.3 -langchain-mongodb==0.2.0 -langchain-ollama==0.3.3 -langchain-huggingface==0.1.0 -langchain-google-genai==2.1.10 -cryptography==45.0.5 +pymongo>=4.12.0,<5 +langchain-mongodb==0.11.0 +langchain-ollama==1.0.1 +langchain-huggingface==1.2.0 +langchain-google-genai==4.2.0 +cryptography==46.0.5 python-magic==0.4.27 python-pptx==1.0.2 xlrd==2.0.2 -pydantic==2.9.2 +pydantic>=2.10.6,<3 chardet==5.2.0 +tenacity>=9.0.0 +msoffcrypto-tool>=6.0.0,<7 diff --git a/setup_venv.sh b/setup_venv.sh new file mode 100755 index 00000000..3b88d40d --- /dev/null +++ b/setup_venv.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 
+VENV_DIR="$SCRIPT_DIR/venv" +REQ_FILE="${1:-requirements.lite.txt}" + +if [[ ! -f "$SCRIPT_DIR/$REQ_FILE" ]]; then + echo "Error: $SCRIPT_DIR/$REQ_FILE not found" + exit 1 +fi + +echo "==> Removing old venv (if any)..." +rm -rf "$VENV_DIR" + +echo "==> Creating fresh venv..." +python3 -m venv "$VENV_DIR" + +echo "==> Upgrading pip..." +"$VENV_DIR/bin/pip" install --upgrade pip --quiet + +echo "==> Installing from $REQ_FILE..." +"$VENV_DIR/bin/pip" install -r "$SCRIPT_DIR/$REQ_FILE" + +echo "" +echo "Done! Activate with:" +echo " source $VENV_DIR/bin/activate" diff --git a/tests/conftest.py b/tests/conftest.py index 9ec5ada9..f03928cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,19 +14,22 @@ # Do this *before* importing any app modules. from langchain_community.vectorstores.pgvector import PGVector + def dummy_post_init(self): # Skip extension creation pass + AsyncPgVector.__post_init__ = dummy_post_init PGVector.__post_init__ = dummy_post_init from langchain_core.documents import Document + class DummyVectorStore: def get_all_ids(self) -> list[str]: return ["testid1", "testid2"] - + def get_filtered_ids(self, ids) -> list[str]: dummy_ids = ["testid1", "testid2"] return [id for id in dummy_ids if id in ids] @@ -40,14 +43,17 @@ async def get_documents_by_ids(self, ids: list[str]) -> list[Document]: def similarity_search_with_score_by_vector(self, embedding, k: int, filter: dict): doc = Document( page_content="Queried content", - metadata={"file_id": filter.get("file_id", "testid1"), "user_id": "testuser"}, + metadata={ + "file_id": filter.get("file_id", "testid1"), + "user_id": "testuser", + }, ) return [(doc, 0.9)] - def add_documents(self, docs, ids): + def add_documents(self, documents, ids=None, **kwargs): return ids - async def aadd_documents(self, docs, ids): + async def aadd_documents(self, documents, ids=None, **kwargs): return ids async def delete(self, ids=None, collection_only: bool = False): diff --git 
a/tests/services/test_async_pg_vector.py b/tests/services/test_async_pg_vector.py new file mode 100644 index 00000000..cb280591 --- /dev/null +++ b/tests/services/test_async_pg_vector.py @@ -0,0 +1,95 @@ +import asyncio +from unittest.mock import patch, MagicMock +import pytest +from langchain_core.documents import Document +from app.services.vector_store.async_pg_vector import AsyncPgVector +from app.services.vector_store.extended_pg_vector import ExtendedPgVector + + +class DummyAsyncPgVector(AsyncPgVector): + """Subclass that skips DB initialization.""" + + def __init__(self): + # Bypass ExtendedPgVector/PGVector __init__ entirely + self._thread_pool = None + self._bind = None # Prevent AttributeError in PGVector.__del__ + + +@pytest.fixture +def store(): + return DummyAsyncPgVector() + + +@pytest.mark.asyncio +async def test_get_all_ids_dispatches_to_super(store): + with patch.object( + ExtendedPgVector, "get_all_ids", return_value=["id1", "id2"] + ) as mock: + result = await store.get_all_ids() + mock.assert_called_once_with() + assert result == ["id1", "id2"] + + +@pytest.mark.asyncio +async def test_get_filtered_ids_passes_ids(store): + with patch.object( + ExtendedPgVector, "get_filtered_ids", return_value=["id1"] + ) as mock: + result = await store.get_filtered_ids(["id1", "id2"]) + mock.assert_called_once_with(["id1", "id2"]) + assert result == ["id1"] + + +@pytest.mark.asyncio +async def test_get_documents_by_ids_passes_ids(store): + docs = [Document(page_content="test", metadata={"file_id": "id1"})] + with patch.object( + ExtendedPgVector, "get_documents_by_ids", return_value=docs + ) as mock: + result = await store.get_documents_by_ids(["id1"]) + mock.assert_called_once_with(["id1"]) + assert result == docs + + +@pytest.mark.asyncio +async def test_delete_passes_args(store): + with patch.object(ExtendedPgVector, "_delete_multiple") as mock: + await store.delete(ids=["id1"], collection_only=True) + mock.assert_called_once_with(["id1"], True) + + 
+@pytest.mark.asyncio +async def test_asimilarity_search_passes_args(store): + expected = [(Document(page_content="test", metadata={}), 0.9)] + with patch.object( + ExtendedPgVector, + "similarity_search_with_score_by_vector", + return_value=expected, + ) as mock: + embedding = [0.1, 0.2, 0.3] + result = await store.asimilarity_search_with_score_by_vector( + embedding, k=5, filter={"file_id": {"$eq": "id1"}} + ) + mock.assert_called_once_with(embedding, 5, {"file_id": {"$eq": "id1"}}) + assert result == expected + + +@pytest.mark.asyncio +async def test_aadd_documents_passes_args(store): + docs = [Document(page_content="test", metadata={})] + with patch.object(ExtendedPgVector, "add_documents", return_value=["id1"]) as mock: + result = await store.aadd_documents(docs, ids=["id1"]) + mock.assert_called_once_with(docs, ids=["id1"]) + assert result == ["id1"] + + +@pytest.mark.asyncio +async def test_run_in_executor_converts_stop_iteration(store): + """StopIteration can't be set on an asyncio.Future — verify it becomes RuntimeError.""" + + def raises_stop(): + raise StopIteration("exhausted") + + with patch.object(ExtendedPgVector, "get_all_ids", side_effect=raises_stop): + with pytest.raises(RuntimeError): + await store.get_all_ids() diff --git a/tests/services/test_database.py b/tests/services/test_database.py index b26b10e5..18570482 100644 --- a/tests/services/test_database.py +++ b/tests/services/test_database.py @@ -1,49 +1,86 @@ +import asyncio + import pytest from app.services.database import ensure_vector_indexes, PSQLDatabase -# Create dummy classes to simulate a database connection and pool -class DummyConnection: +class CapturingConnection: + """Records every SQL statement passed to execute().""" + + def __init__(self): + self.statements = [] + async def fetchval(self, query, index_name): - # Simulate that the index does not exist return False async def execute(self, query): + self.statements.append(query) return "Executed" -class DummyAcquire: +class 
CapturingAcquire: + def __init__(self, conn): + self._conn = conn + async def __aenter__(self): - return DummyConnection() + return self._conn async def __aexit__(self, exc_type, exc, tb): pass -class DummyPool: +class CapturingPool: + def __init__(self, conn): + self._conn = conn + def acquire(self): - return DummyAcquire() + return CapturingAcquire(self._conn) -class DummyDatabase: - pool = DummyPool() +def _run_with_captured_conn(monkeypatch): + """Run ensure_vector_indexes() and return the captured connection.""" + conn = CapturingConnection() + pool = CapturingPool(conn) - @classmethod - async def get_pool(cls): - return cls.pool + async def fake_get_pool(): + return pool + monkeypatch.setattr(PSQLDatabase, "get_pool", fake_get_pool) + asyncio.run(ensure_vector_indexes()) + return conn -@pytest.fixture -def dummy_pool(monkeypatch): - monkeypatch.setattr(PSQLDatabase, "get_pool", DummyDatabase.get_pool) - return DummyPool() +def test_ensure_vector_indexes(monkeypatch): + conn = _run_with_captured_conn(monkeypatch) + assert len(conn.statements) > 0 -import asyncio + +def test_ensure_vector_indexes_do_block_dollar_quoting(monkeypatch): + """DO block must use $$ dollar-quoting, not single $.""" + conn = _run_with_captured_conn(monkeypatch) + do_block = next(s for s in conn.statements if "DO" in s) + assert "$$" in do_block, "DO block must use $$ dollar-quoting" + + +def test_ensure_vector_indexes_jsonb_migration_sql(monkeypatch): + """Migration block contains the correct ALTER COLUMN and schema filter.""" + conn = _run_with_captured_conn(monkeypatch) + do_block = next(s for s in conn.statements if "DO" in s) + assert "TYPE JSONB" in do_block + assert "cmetadata::jsonb" in do_block + assert "table_schema = current_schema()" in do_block + + +def test_ensure_vector_indexes_lock_timeout(monkeypatch): + """Migration sets a lock_timeout before ALTER TABLE.""" + conn = _run_with_captured_conn(monkeypatch) + do_block = next(s for s in conn.statements if "DO" in s) + 
assert "lock_timeout" in do_block -@pytest.mark.asyncio -async def test_ensure_vector_indexes(monkeypatch, dummy_pool): - result = await ensure_vector_indexes() - # If no exceptions are raised, the function worked as expected. - assert result is None +def test_ensure_vector_indexes_gin_index(monkeypatch): + """GIN index with jsonb_path_ops is created.""" + conn = _run_with_captured_conn(monkeypatch) + gin_stmt = next(s for s in conn.statements if "ix_cmetadata_gin" in s) + assert "jsonb_path_ops" in gin_stmt + assert "USING gin" in gin_stmt diff --git a/tests/services/test_vector_store_factory.py b/tests/services/test_vector_store_factory.py new file mode 100644 index 00000000..d1139318 --- /dev/null +++ b/tests/services/test_vector_store_factory.py @@ -0,0 +1,135 @@ +"""Tests for vector store factory shutdown and cleanup logic.""" + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.vector_store import factory +from app.services.vector_store.factory import close_vector_store_connections + + +def test_close_vector_store_connections_mongo(): + """close_vector_store_connections closes the module-level MongoClient.""" + mock_client = MagicMock() + factory._mongo_client = mock_client + + try: + close_vector_store_connections(vector_store=None) + mock_client.close.assert_called_once() + assert factory._mongo_client is None + finally: + factory._mongo_client = None + + +def test_close_vector_store_connections_sqlalchemy(): + """close_vector_store_connections disposes the SQLAlchemy engine on the vector store.""" + mock_engine = MagicMock() + mock_engine.dispose = MagicMock() + + mock_vs = MagicMock() + mock_vs._bind = mock_engine + + close_vector_store_connections(mock_vs) + mock_engine.dispose.assert_called_once() + + +def test_close_vector_store_connections_idempotent(): + """Calling close_vector_store_connections twice is safe.""" + mock_client = MagicMock() + factory._mongo_client = mock_client + + mock_engine = 
MagicMock() + mock_vs = MagicMock() + mock_vs._bind = mock_engine + + try: + close_vector_store_connections(mock_vs) + close_vector_store_connections(mock_vs) + + # Mongo closed once, then global set to None so second call skips it + mock_client.close.assert_called_once() + # Engine dispose called twice (harmless — SQLAlchemy handles it) + assert mock_engine.dispose.call_count == 2 + finally: + factory._mongo_client = None + + +def test_close_vector_store_connections_no_bind(): + """close_vector_store_connections handles vector stores without _bind.""" + mock_vs = MagicMock(spec=[]) # No attributes at all + # Should not raise + close_vector_store_connections(mock_vs) + + +def test_close_vector_store_connections_none(): + """close_vector_store_connections handles None vector store.""" + close_vector_store_connections(None) + + +def test_get_vector_store_atlas_mongo_closes_previous_client(): + """Calling get_vector_store(atlas-mongo) twice closes the first MongoClient.""" + factory._mongo_client = None + + with patch("app.services.vector_store.factory.MongoClient") as MockMC: + mock_client_1 = MagicMock() + mock_client_2 = MagicMock() + MockMC.side_effect = [mock_client_1, mock_client_2] + + mock_embeddings = MagicMock() + + with patch("app.services.vector_store.factory.AtlasMongoVector"): + factory.get_vector_store( + "conn1", mock_embeddings, "coll", mode="atlas-mongo", search_index="idx" + ) + assert factory._mongo_client is mock_client_1 + mock_client_1.close.assert_not_called() + + factory.get_vector_store( + "conn2", mock_embeddings, "coll", mode="atlas-mongo", search_index="idx" + ) + # First client should have been closed before overwrite + mock_client_1.close.assert_called_once() + assert factory._mongo_client is mock_client_2 + + factory._mongo_client = None + + +def test_get_vector_store_sync_passes_use_jsonb(): + """Sync PgVector must be instantiated with use_jsonb=True.""" + with patch("app.services.vector_store.factory.ExtendedPgVector") as MockPG: + 
mock_embeddings = MagicMock() + factory.get_vector_store("conn", mock_embeddings, "coll", mode="sync") + _, kwargs = MockPG.call_args + assert kwargs.get("use_jsonb") is True + + +def test_get_vector_store_async_passes_use_jsonb(): + """Async PgVector must be instantiated with use_jsonb=True.""" + with patch("app.services.vector_store.factory.AsyncPgVector") as MockPG: + mock_embeddings = MagicMock() + factory.get_vector_store("conn", mock_embeddings, "coll", mode="async") + _, kwargs = MockPG.call_args + assert kwargs.get("use_jsonb") is True + + +def test_load_file_content_cleans_up_on_lazy_load_failure(): + """cleanup_temp_encoding_file is called even when lazy_load() raises.""" + from app.routes.document_routes import load_file_content + + mock_loader = MagicMock() + mock_loader._temp_filepath = "/tmp/fake.csv" + mock_loader.lazy_load.side_effect = RuntimeError("disk error") + + with patch( + "app.routes.document_routes.get_loader", + return_value=(mock_loader, True, "csv"), + ): + with patch( + "app.routes.document_routes.cleanup_temp_encoding_file" + ) as mock_cleanup: + with pytest.raises(RuntimeError, match="disk error"): + asyncio.run( + load_file_content("f.csv", "text/csv", "/fake/path", executor=None) + ) + mock_cleanup.assert_called_once_with(mock_loader) diff --git a/tests/test_batch_processing.py b/tests/test_batch_processing.py new file mode 100644 index 00000000..c187ef22 --- /dev/null +++ b/tests/test_batch_processing.py @@ -0,0 +1,574 @@ +# tests/test_batch_processing.py +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from langchain_core.documents import Document + + +class TestBatchProcessing: + """Test batch processing functions.""" + + @pytest.fixture + def mock_documents(self): + """Create mock documents for testing.""" + return [ + Document(page_content=f"content_{i}", metadata={"file_id": "test_file"}) + for i in range(10) + ] + + @pytest.fixture + def mock_async_vector_store(self): + """Create mock async vector 
store.""" + store = AsyncMock() + store.aadd_documents = AsyncMock(return_value=["id1", "id2"]) + store.delete = AsyncMock() + return store + + @pytest.fixture + def mock_sync_vector_store(self): + """Create mock sync vector store.""" + store = Mock() + store.add_documents = Mock(return_value=["id1", "id2"]) + store.delete = Mock() + return store + + # --- Async Pipeline Tests --- + + @pytest.mark.asyncio + async def test_async_pipeline_basic(self, mock_documents, mock_async_vector_store): + """Test basic async pipeline processing.""" + from app.routes.document_routes import _process_documents_async_pipeline + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + result = await _process_documents_async_pipeline( + documents=mock_documents, + file_id="test_file", + vector_store=mock_async_vector_store, + executor=None, + ) + + assert len(result) > 0 + assert mock_async_vector_store.aadd_documents.called + + @pytest.mark.asyncio + async def test_async_pipeline_single_batch(self, mock_async_vector_store): + """Test when all documents fit in one batch.""" + from app.routes.document_routes import _process_documents_async_pipeline + + docs = [Document(page_content="test", metadata={})] + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 10): + result = await _process_documents_async_pipeline( + documents=docs, + file_id="test_file", + vector_store=mock_async_vector_store, + executor=None, + ) + + assert mock_async_vector_store.aadd_documents.call_count == 1 + + @pytest.mark.asyncio + async def test_async_pipeline_exact_batch_size(self, mock_async_vector_store): + """Test when document count equals batch size.""" + from app.routes.document_routes import _process_documents_async_pipeline + + docs = [Document(page_content=f"test_{i}", metadata={}) for i in range(5)] + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 5): + result = await _process_documents_async_pipeline( + documents=docs, + file_id="test_file", + 
vector_store=mock_async_vector_store, + executor=None, + ) + + assert mock_async_vector_store.aadd_documents.call_count == 1 + + @pytest.mark.asyncio + async def test_async_pipeline_empty_documents(self, mock_async_vector_store): + """Test with empty document list.""" + from app.routes.document_routes import _process_documents_async_pipeline + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + result = await _process_documents_async_pipeline( + documents=[], + file_id="test_file", + vector_store=mock_async_vector_store, + executor=None, + ) + + assert result == [] + assert not mock_async_vector_store.aadd_documents.called + + @pytest.mark.asyncio + async def test_async_pipeline_rollback_on_error( + self, mock_documents, mock_async_vector_store + ): + """Test that rollback occurs when insertion fails after some success.""" + from app.routes.document_routes import _process_documents_async_pipeline + + # First batch succeeds, second batch fails + mock_async_vector_store.aadd_documents = AsyncMock( + side_effect=[["id1"], Exception("DB error")] + ) + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + with pytest.raises(Exception, match="DB error"): + await _process_documents_async_pipeline( + documents=mock_documents, + file_id="test_file", + vector_store=mock_async_vector_store, + executor=None, + ) + + mock_async_vector_store.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_async_pipeline_no_rollback_on_first_batch_error( + self, mock_documents, mock_async_vector_store + ): + """Test that no rollback occurs if first batch fails (nothing inserted).""" + from app.routes.document_routes import _process_documents_async_pipeline + + mock_async_vector_store.aadd_documents = AsyncMock( + side_effect=Exception("DB error") + ) + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + with pytest.raises(Exception): + await _process_documents_async_pipeline( + documents=mock_documents, + 
file_id="test_file", + vector_store=mock_async_vector_store, + executor=None, + ) + + # Should not attempt rollback since nothing was inserted + assert not mock_async_vector_store.delete.called + + # --- Sync Batched Tests --- + + @pytest.mark.asyncio + async def test_sync_batched_basic(self, mock_documents, mock_sync_vector_store): + """Test basic sync batch processing.""" + from app.routes.document_routes import _process_documents_batched_sync + import asyncio + + # Create a real executor for the test + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=2) as executor: + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + result = await _process_documents_batched_sync( + documents=mock_documents, + file_id="test_file", + vector_store=mock_sync_vector_store, + executor=executor, + ) + + assert len(result) > 0 + assert mock_sync_vector_store.add_documents.called + + @pytest.mark.asyncio + async def test_sync_batched_empty_documents(self, mock_sync_vector_store): + """Test sync batch processing with empty documents.""" + from app.routes.document_routes import _process_documents_batched_sync + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + result = await _process_documents_batched_sync( + documents=[], + file_id="test_file", + vector_store=mock_sync_vector_store, + executor=None, + ) + + assert result == [] + assert not mock_sync_vector_store.add_documents.called + + @pytest.mark.asyncio + async def test_sync_batched_rollback_on_error( + self, mock_documents, mock_sync_vector_store + ): + """Test sync rollback behavior.""" + from app.routes.document_routes import _process_documents_batched_sync + from concurrent.futures import ThreadPoolExecutor + + # First batch succeeds, second batch fails + mock_sync_vector_store.add_documents = Mock( + side_effect=[["id1"], Exception("DB error")] + ) + + with ThreadPoolExecutor(max_workers=2) as executor: + with 
patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + with pytest.raises(Exception): + await _process_documents_batched_sync( + documents=mock_documents, + file_id="test_file", + vector_store=mock_sync_vector_store, + executor=executor, + ) + + mock_sync_vector_store.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_sync_batched_no_rollback_on_first_error( + self, mock_documents, mock_sync_vector_store + ): + """Test that no rollback occurs if first batch fails.""" + from app.routes.document_routes import _process_documents_batched_sync + from concurrent.futures import ThreadPoolExecutor + + mock_sync_vector_store.add_documents = Mock(side_effect=Exception("DB error")) + + with ThreadPoolExecutor(max_workers=2) as executor: + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 3): + with pytest.raises(Exception): + await _process_documents_batched_sync( + documents=mock_documents, + file_id="test_file", + vector_store=mock_sync_vector_store, + executor=executor, + ) + + # Should not attempt rollback since nothing was inserted + assert not mock_sync_vector_store.delete.called + + +class TestBatchConfiguration: + """Test configuration and edge cases.""" + + def test_batch_calculation(self): + """Test batch count calculation using the utility function.""" + from app.routes.document_routes import calculate_num_batches + + # 10 docs, batch size 3 = 4 batches (3+3+3+1) + assert calculate_num_batches(10, 3) == 4 + + # Exact division + assert calculate_num_batches(9, 3) == 3 + + # Single item + assert calculate_num_batches(1, 3) == 1 + + # Zero items + assert calculate_num_batches(0, 3) == 0 + + # Batch size larger than total + assert calculate_num_batches(5, 10) == 1 + + # Edge case: batch_size of 0 returns 1 (fallback) + assert calculate_num_batches(10, 0) == 1 + + # Edge case: batch_size of 1 + assert calculate_num_batches(5, 1) == 5 + + def test_embedding_batch_size_from_env(self): + """Test that EMBEDDING_BATCH_SIZE is read from 
environment variable.""" + import os + from importlib import reload + + # Save current value + original = os.environ.get("EMBEDDING_BATCH_SIZE") + + try: + # Set a specific test value + os.environ["EMBEDDING_BATCH_SIZE"] = "999" + + import app.config as config_module + + reload(config_module) + + assert config_module.EMBEDDING_BATCH_SIZE == 999 + finally: + # Restore original value + if original is not None: + os.environ["EMBEDDING_BATCH_SIZE"] = original + elif "EMBEDDING_BATCH_SIZE" in os.environ: + del os.environ["EMBEDDING_BATCH_SIZE"] + + def test_embedding_max_queue_size_from_env(self): + """Test that EMBEDDING_MAX_QUEUE_SIZE is read from environment variable.""" + import os + from importlib import reload + + original = os.environ.get("EMBEDDING_MAX_QUEUE_SIZE") + + try: + os.environ["EMBEDDING_MAX_QUEUE_SIZE"] = "10" + + import app.config as config_module + + reload(config_module) + + assert config_module.EMBEDDING_MAX_QUEUE_SIZE == 10 + finally: + if original is not None: + os.environ["EMBEDDING_MAX_QUEUE_SIZE"] = original + elif "EMBEDDING_MAX_QUEUE_SIZE" in os.environ: + del os.environ["EMBEDDING_MAX_QUEUE_SIZE"] + + +class TestBatchSizeEdgeCases: + """Test various batch size configurations.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "batch_size,doc_count,expected_batches", + [ + (1, 5, 5), # Each doc separate + (5, 5, 1), # Exact fit + (10, 5, 1), # Batch larger than docs + (3, 10, 4), # Normal case + (100, 1, 1), # Large batch, single doc + ], + ) + async def test_batch_counts(self, batch_size, doc_count, expected_batches): + """Test various batch size and document count combinations.""" + from app.routes.document_routes import _process_documents_async_pipeline + + mock_store = AsyncMock() + mock_store.aadd_documents = AsyncMock(return_value=["id"]) + + docs = [ + Document(page_content=f"doc_{i}", metadata={}) for i in range(doc_count) + ] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", batch_size): + await 
_process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + assert mock_store.aadd_documents.call_count == expected_batches + + @pytest.mark.asyncio + async def test_large_batch_size_single_call(self): + """Test that a very large batch size results in a single call.""" + from app.routes.document_routes import _process_documents_async_pipeline + + mock_store = AsyncMock() + mock_store.aadd_documents = AsyncMock(return_value=["id"]) + + docs = [Document(page_content=f"doc_{i}", metadata={}) for i in range(100)] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 1000): + await _process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + assert mock_store.aadd_documents.call_count == 1 + + @pytest.mark.asyncio + async def test_batch_size_one_multiple_calls(self): + """Test that batch size of 1 results in many calls.""" + from app.routes.document_routes import _process_documents_async_pipeline + + mock_store = AsyncMock() + mock_store.aadd_documents = AsyncMock(return_value=["id"]) + + docs = [Document(page_content=f"doc_{i}", metadata={}) for i in range(5)] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 1): + await _process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + assert mock_store.aadd_documents.call_count == 5 + + +class TestProducerConsumerPattern: + """Test the producer-consumer pattern behavior.""" + + @pytest.mark.asyncio + async def test_producer_signals_completion_on_success(self): + """Test that producer always signals completion.""" + from app.routes.document_routes import _process_documents_async_pipeline + + mock_store = AsyncMock() + mock_store.aadd_documents = AsyncMock(return_value=["id"]) + + docs = [Document(page_content="test", metadata={})] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 1): + result = await 
_process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + # If we get here without hanging, the producer signaled completion + assert result == ["id"] + + @pytest.mark.asyncio + async def test_consumer_handles_exception_in_batch(self): + """Test that consumer properly handles exceptions from vector store.""" + from app.routes.document_routes import _process_documents_async_pipeline + + mock_store = AsyncMock() + mock_store.aadd_documents = AsyncMock(side_effect=ValueError("Test error")) + mock_store.delete = AsyncMock() + + docs = [Document(page_content="test", metadata={})] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 1): + with pytest.raises(ValueError, match="Test error"): + await _process_documents_async_pipeline( + documents=docs, + file_id="test", + vector_store=mock_store, + executor=None, + ) + + @pytest.mark.asyncio + async def test_all_ids_collected_across_batches(self): + """Test that IDs from all batches are collected.""" + from app.routes.document_routes import _process_documents_async_pipeline + + mock_store = AsyncMock() + # Return different IDs for each batch + mock_store.aadd_documents = AsyncMock( + side_effect=[["id1", "id2"], ["id3", "id4"], ["id5"]] + ) + + docs = [Document(page_content=f"doc_{i}", metadata={}) for i in range(5)] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 2): + result = await _process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + assert len(result) == 5 + assert result == ["id1", "id2", "id3", "id4", "id5"] + + +class TestSyncBatchedMongoCompat: + """Regression tests for MongoDB-compatible batch processing (PR #266).""" + + @pytest.mark.asyncio + async def test_sync_batched_rejects_documents_keyword_with_legacy_store(self): + """Regression: a store using 'docs' (not 'documents') as its param name must + still work, proving the call site uses positional — not keyword — 
dispatch.""" + from app.routes.document_routes import _process_documents_batched_sync + from concurrent.futures import ThreadPoolExecutor + + class LegacyMongoStore: + """No **kwargs — documents= keyword would raise TypeError.""" + + def add_documents(self, docs, ids): + return [f"id_{i}" for i in range(len(docs))] + + def delete(self, ids=None): + pass + + docs = [ + Document(page_content=f"content_{i}", metadata={"file_id": "test_file"}) + for i in range(5) + ] + + with ThreadPoolExecutor(max_workers=1) as executor: + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 2): + result = await _process_documents_batched_sync( + documents=docs, + file_id="test_file", + vector_store=LegacyMongoStore(), + executor=executor, + ) + + assert len(result) == 5 + + @pytest.mark.asyncio + async def test_sync_batched_with_base_class_signature(self): + """Positional dispatch also works with a store matching the VectorStore base class signature.""" + from app.routes.document_routes import _process_documents_batched_sync + from concurrent.futures import ThreadPoolExecutor + + class BaseClassStore: + def add_documents(self, documents, ids=None, **kwargs): + return [f"id_{i}" for i in range(len(documents))] + + def delete(self, ids=None): + pass + + docs = [ + Document(page_content=f"content_{i}", metadata={"file_id": "test_file"}) + for i in range(3) + ] + + with ThreadPoolExecutor(max_workers=1) as executor: + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 2): + result = await _process_documents_batched_sync( + documents=docs, + file_id="test_file", + vector_store=BaseClassStore(), + executor=executor, + ) + + assert len(result) == 3 + + +class TestMongoIdGeneration: + """Test that digest-based ID generation produces unique IDs across batches.""" + + def test_unique_ids_across_batches(self): + """Simulate multiple batch calls and verify no ID collisions.""" + import hashlib + + file_id = "test_file" + all_ids = [] + + for batch_idx in range(3): + 
batch_docs = [ + Document( + page_content=f"content_{batch_idx * 3 + i}", + metadata={ + "file_id": file_id, + "digest": hashlib.md5( + f"content_{batch_idx * 3 + i}".encode() + ).hexdigest(), + }, + ) + for i in range(3) + ] + f_ids = [ + f"{file_id}_{doc.metadata.get('digest') or hashlib.md5(doc.page_content.encode()).hexdigest()}" + for doc in batch_docs + ] + all_ids.extend(f_ids) + + assert len(all_ids) == 9 + assert len(set(all_ids)) == 9 + + def test_old_sequential_ids_would_collide(self): + """Demonstrates the old per-batch range(len) approach caused ID collisions.""" + file_id = "test_file" + all_ids = [] + + for _ in range(3): + batch_size = 3 + old_ids = [f"{file_id}_{i}" for i in range(batch_size)] + all_ids.extend(old_ids) + + assert len(all_ids) == 9 + assert len(set(all_ids)) == 3 + + def test_fallback_to_content_hash_without_digest_metadata(self): + """IDs are unique even when documents lack a 'digest' metadata field.""" + import hashlib + + file_id = "test_file" + docs = [ + Document(page_content=f"content_{i}", metadata={"file_id": file_id}) + for i in range(5) + ] + + f_ids = [ + f"{file_id}_{doc.metadata.get('digest') or hashlib.md5(doc.page_content.encode()).hexdigest()}" + for doc in docs + ] + + assert len(set(f_ids)) == 5 + + def test_empty_documents_returns_empty(self): + """AtlasMongoVector.add_documents with empty list returns empty.""" + pytest.importorskip("langchain_mongodb", reason="requires langchain_mongodb") + from unittest.mock import MagicMock + from app.services.vector_store.atlas_mongo_vector import AtlasMongoVector + + store = MagicMock(spec=AtlasMongoVector) + store.add_documents = AtlasMongoVector.add_documents.__get__(store) + result = store.add_documents([]) + assert result == [] diff --git a/tests/test_batch_processing_integration.py b/tests/test_batch_processing_integration.py new file mode 100644 index 00000000..3b91e829 --- /dev/null +++ b/tests/test_batch_processing_integration.py @@ -0,0 +1,311 @@ +# 
tests/test_batch_processing_integration.py +""" +Integration tests for batch processing. + +These tests verify actual memory behavior and require more resources to run. +Mark with @pytest.mark.integration to skip in normal test runs. + +Run with: pytest tests/test_batch_processing_integration.py -v -m integration +""" +import pytest +import tracemalloc +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from langchain_core.documents import Document + + +class TestMemoryOptimization: + """Tests to verify memory optimization behavior.""" + + @pytest.mark.asyncio + async def test_memory_bounded_by_batch_size(self): + """ + Test that memory usage is bounded by batch size, not total documents. + + This test verifies that processing many documents in batches doesn't + accumulate memory proportionally to the total document count. + """ + from app.routes.document_routes import _process_documents_async_pipeline + + # Track how many documents are in memory at any time + max_docs_in_memory = 0 + current_docs_in_memory = 0 + + async def tracking_add_documents(docs, ids=None, executor=None): + nonlocal max_docs_in_memory, current_docs_in_memory + current_docs_in_memory = len(docs) + max_docs_in_memory = max(max_docs_in_memory, current_docs_in_memory) + return [f"id_{i}" for i in range(len(docs))] + + mock_store = AsyncMock() + mock_store.aadd_documents = tracking_add_documents + mock_store.delete = AsyncMock() + + # Create 100 documents + docs = [ + Document(page_content=f"doc_{i}" * 100, metadata={"idx": i}) + for i in range(100) + ] + + # Process with batch size of 10 + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 10): + result = await _process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + # Verify we got all 100 IDs back + assert len(result) == 100 + + # Verify max docs in memory was bounded by batch size (10), not total (100) + assert ( + max_docs_in_memory <= 10 + ), f"Expected max {10} 
docs in memory at once, but saw {max_docs_in_memory}" + + @pytest.mark.asyncio + async def test_memory_tracking_with_tracemalloc(self): + """ + Test memory usage with tracemalloc. + + This test uses Python's tracemalloc to verify memory behavior. + Note: This is a sanity check, not a strict memory bound test. + """ + from app.routes.document_routes import _process_documents_async_pipeline + + mock_store = AsyncMock() + mock_store.aadd_documents = AsyncMock(return_value=["id"]) + mock_store.delete = AsyncMock() + + # Create documents with substantial content + doc_count = 50 + docs = [ + Document( + page_content=f"Document content {i} " * 100, # ~2KB per doc + metadata={"file_id": "test", "idx": i}, + ) + for i in range(doc_count) + ] + + tracemalloc.start() + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 5): + await _process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + + # Log memory usage for debugging + print(f"Current memory: {current / 1024:.2f} KB") + print(f"Peak memory: {peak / 1024:.2f} KB") + + # The test passes if it completes without OOM + # Actual memory bounds depend on Python internals and test environment + assert True + + @pytest.mark.asyncio + async def test_batch_processing_maintains_order(self): + """Test that document IDs are returned in correct order across batches.""" + from app.routes.document_routes import _process_documents_async_pipeline + + call_order = [] + + async def ordered_add_documents(docs, ids=None, executor=None): + batch_ids = [f"id_{docs[0].metadata['idx']}_to_{docs[-1].metadata['idx']}"] + call_order.append(docs[0].metadata["idx"]) + return [f"id_{d.metadata['idx']}" for d in docs] + + mock_store = AsyncMock() + mock_store.aadd_documents = ordered_add_documents + + docs = [ + Document(page_content=f"doc_{i}", metadata={"idx": i}) for i in range(15) + ] + + with 
patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 5): + result = await _process_documents_async_pipeline( + documents=docs, file_id="test", vector_store=mock_store, executor=None + ) + + # Verify batches were processed in order + assert call_order == [0, 5, 10], f"Batches processed out of order: {call_order}" + + # Verify all IDs returned + assert len(result) == 15 + + +class TestSyncBatchedMemory: + """Memory tests for sync batched processing.""" + + @pytest.mark.asyncio + async def test_sync_memory_bounded_by_batch_size(self): + """Test that sync batch processing bounds memory by batch size.""" + from app.routes.document_routes import _process_documents_batched_sync + from concurrent.futures import ThreadPoolExecutor + + max_docs_in_batch = 0 + + def tracking_add_documents(documents, ids=None): + nonlocal max_docs_in_batch + max_docs_in_batch = max(max_docs_in_batch, len(documents)) + return [f"id_{i}" for i in range(len(documents))] + + mock_store = Mock() + mock_store.add_documents = tracking_add_documents + mock_store.delete = Mock() + + docs = [ + Document(page_content=f"doc_{i}" * 100, metadata={"idx": i}) + for i in range(50) + ] + + with ThreadPoolExecutor(max_workers=2) as executor: + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 10): + result = await _process_documents_batched_sync( + documents=docs, + file_id="test", + vector_store=mock_store, + executor=executor, + ) + + assert len(result) == 50 + assert ( + max_docs_in_batch <= 10 + ), f"Expected max {10} docs per batch, but saw {max_docs_in_batch}" + + +class TestBatchProcessingResilience: + """Tests for error handling and resilience.""" + + @pytest.mark.asyncio + async def test_partial_failure_tracks_inserted_ids(self): + """Test that we track which IDs were inserted before failure.""" + from app.routes.document_routes import _process_documents_async_pipeline + + inserted_batches = [] + + async def failing_add_documents(docs, ids=None, executor=None): + batch_num = 
len(inserted_batches) + 1 + if batch_num == 3: # Fail on third batch + raise Exception("Simulated DB error") + result = [f"batch{batch_num}_id_{i}" for i in range(len(docs))] + inserted_batches.append(result) + return result + + mock_store = AsyncMock() + mock_store.aadd_documents = failing_add_documents + mock_store.delete = AsyncMock() + + docs = [ + Document(page_content=f"doc_{i}", metadata={"idx": i}) for i in range(15) + ] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 5): + with pytest.raises(Exception, match="Simulated DB error"): + await _process_documents_async_pipeline( + documents=docs, + file_id="test_file", + vector_store=mock_store, + executor=None, + ) + + # Verify rollback was called because we had inserted batches + mock_store.delete.assert_called_once() + + # Verify we inserted 2 batches before failure + assert len(inserted_batches) == 2 + + @pytest.mark.asyncio + async def test_rollback_called_with_correct_file_id(self): + """Test that rollback uses the correct file_id.""" + from app.routes.document_routes import _process_documents_async_pipeline + + async def failing_on_second(docs, ids=None, executor=None): + if len(docs) > 0 and docs[0].metadata.get("idx", 0) >= 5: + raise Exception("Fail") + return ["id1"] + + mock_store = AsyncMock() + mock_store.aadd_documents = failing_on_second + mock_store.delete = AsyncMock() + + docs = [ + Document(page_content=f"doc_{i}", metadata={"idx": i}) for i in range(10) + ] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 5): + with pytest.raises(Exception): + await _process_documents_async_pipeline( + documents=docs, + file_id="my_unique_file_id", + vector_store=mock_store, + executor=None, + ) + + # Verify delete was called with the correct file_id + mock_store.delete.assert_called_once() + call_kwargs = mock_store.delete.call_args + assert call_kwargs[1]["ids"] == ["my_unique_file_id"] + + +class TestConfigurationBehavior: + """Tests for configuration-driven behavior.""" 
+ + @pytest.mark.asyncio + async def test_batch_size_zero_uses_original_path(self): + """Test that EMBEDDING_BATCH_SIZE=0 uses the non-batched code path.""" + from app.routes.document_routes import store_data_in_vector_db + from app.services.vector_store.async_pg_vector import AsyncPgVector + + mock_store = AsyncMock(spec=AsyncPgVector) + mock_store.aadd_documents = AsyncMock(return_value=["id1", "id2"]) + + docs = [Document(page_content="test", metadata={})] + + with patch("app.routes.document_routes.vector_store", mock_store): + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", 0): + with patch("app.routes.document_routes.isinstance", return_value=True): + result = await store_data_in_vector_db( + data=docs, + file_id="test_file", + user_id="test_user", + executor=None, + ) + + # When batch size is 0, aadd_documents should be called directly + # (not through the pipeline) + assert mock_store.aadd_documents.called + + @pytest.mark.asyncio + async def test_different_batch_sizes_produce_correct_batches(self): + """Test that different batch sizes produce expected number of batches.""" + from app.routes.document_routes import _process_documents_async_pipeline + + test_cases = [ + (10, 100, 10), # 100 docs / 10 batch = 10 batches + (25, 100, 4), # 100 docs / 25 batch = 4 batches + (100, 100, 1), # 100 docs / 100 batch = 1 batch + (150, 100, 1), # 100 docs / 150 batch = 1 batch (batch larger than docs) + (7, 20, 3), # 20 docs / 7 batch = 3 batches (20 = 7+7+6) + ] + + for batch_size, doc_count, expected_batches in test_cases: + mock_store = AsyncMock() + mock_store.aadd_documents = AsyncMock(return_value=["id"]) + + docs = [ + Document(page_content=f"doc_{i}", metadata={}) for i in range(doc_count) + ] + + with patch("app.routes.document_routes.EMBEDDING_BATCH_SIZE", batch_size): + await _process_documents_async_pipeline( + documents=docs, + file_id="test", + vector_store=mock_store, + executor=None, + ) + + actual_batches = 
mock_store.aadd_documents.call_count + assert actual_batches == expected_batches, ( + f"batch_size={batch_size}, docs={doc_count}: " + f"expected {expected_batches} batches, got {actual_batches}" + ) diff --git a/tests/test_main.py b/tests/test_main.py index e3807019..a7fce89a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor from main import app +from app.routes import document_routes client = TestClient(app) @@ -27,6 +28,18 @@ def auth_headers(): @pytest.fixture(autouse=True) def override_vector_store(monkeypatch): from app.config import vector_store + from app.services.vector_store.async_pg_vector import AsyncPgVector + from app.routes import document_routes + + # Clear the LRU cache and patch the cached function to return dummy embeddings + document_routes.get_cached_query_embedding.cache_clear() + + def dummy_get_cached_query_embedding(query): + return [0.1, 0.2, 0.3] + + monkeypatch.setattr( + document_routes, "get_cached_query_embedding", dummy_get_cached_query_embedding + ) # Initialize thread pool for tests since TestClient doesn't run lifespan if not hasattr(app.state, "thread_pool") or app.state.thread_pool is None: @@ -34,31 +47,31 @@ def override_vector_store(monkeypatch): max_workers=2, thread_name_prefix="test-worker" ) - # Override get_all_ids as an async function. - async def dummy_get_all_ids(executor=None): + # Override get_all_ids as an async function - patch at CLASS level to bypass run_in_executor + async def dummy_get_all_ids(self, executor=None): return ["testid1", "testid2"] - monkeypatch.setattr(vector_store, "get_all_ids", dummy_get_all_ids) + monkeypatch.setattr(AsyncPgVector, "get_all_ids", dummy_get_all_ids) # Override get_filtered_ids as an async function. 
- async def dummy_get_filtered_ids(ids, executor=None): + async def dummy_get_filtered_ids(self, ids, executor=None): dummy_ids = ["testid1", "testid2"] return [id for id in dummy_ids if id in ids] - monkeypatch.setattr(vector_store, "get_filtered_ids", dummy_get_filtered_ids) + monkeypatch.setattr(AsyncPgVector, "get_filtered_ids", dummy_get_filtered_ids) # Override get_documents_by_ids as an async function. - async def dummy_get_documents_by_ids(ids, executor=None): + async def dummy_get_documents_by_ids(self, ids, executor=None): return [ Document(page_content="Test content", metadata={"file_id": id}) for id in ids ] monkeypatch.setattr( - vector_store, "get_documents_by_ids", dummy_get_documents_by_ids + AsyncPgVector, "get_documents_by_ids", dummy_get_documents_by_ids ) - # Override embedding_function. + # Override embedding_function with a dummy that doesn't call OpenAI class DummyEmbedding: def embed_query(self, query): return [0.1, 0.2, 0.3] @@ -66,7 +79,7 @@ def embed_query(self, query): vector_store.embedding_function = DummyEmbedding() # Override similarity search to return a tuple (Document, score). 
- def dummy_similarity_search_with_score_by_vector(embedding, k, filter): + def dummy_similarity_search_with_score_by_vector(self, embedding, k, filter): doc = Document( page_content="Queried content", metadata={ @@ -77,7 +90,7 @@ def dummy_similarity_search_with_score_by_vector(embedding, k, filter): return [(doc, 0.9)] async def dummy_asimilarity_search_with_score_by_vector( - embedding, k, filter=None, executor=None + self, embedding, k, filter=None, executor=None ): doc = Document( page_content="Queried content", @@ -89,31 +102,31 @@ async def dummy_asimilarity_search_with_score_by_vector( return [(doc, 0.9)] monkeypatch.setattr( - vector_store, + AsyncPgVector, "similarity_search_with_score_by_vector", dummy_similarity_search_with_score_by_vector, ) monkeypatch.setattr( - vector_store, + AsyncPgVector, "asimilarity_search_with_score_by_vector", dummy_asimilarity_search_with_score_by_vector, ) # Override document addition functions. - def dummy_add_documents(docs, ids): + def dummy_add_documents(self, docs, ids): return ids - async def dummy_aadd_documents(docs, ids=None, executor=None): + async def dummy_aadd_documents(self, docs, ids=None, executor=None): return ids - monkeypatch.setattr(vector_store, "add_documents", dummy_add_documents) - monkeypatch.setattr(vector_store, "aadd_documents", dummy_aadd_documents) + monkeypatch.setattr(AsyncPgVector, "add_documents", dummy_add_documents) + monkeypatch.setattr(AsyncPgVector, "aadd_documents", dummy_aadd_documents) # Override delete function. 
- async def dummy_delete(ids=None, collection_only=False, executor=None): + async def dummy_delete(self, ids=None, collection_only=False, executor=None): return None - monkeypatch.setattr(vector_store, "delete", dummy_delete) + monkeypatch.setattr(AsyncPgVector, "delete", dummy_delete) def test_get_all_ids(auth_headers): @@ -161,12 +174,15 @@ def test_query_embeddings_by_file_id(auth_headers): def test_embed_local_file(tmp_path, auth_headers, monkeypatch): - # Create a temporary file. + # Monkeypatch RAG_UPLOAD_DIR so the file is within the allowed directory. + monkeypatch.setattr(document_routes, "RAG_UPLOAD_DIR", str(tmp_path)) + + # Create a temporary file inside the patched upload dir. test_file = tmp_path / "test.txt" test_file.write_text("This is a test document.") data = { - "filepath": str(test_file), + "filepath": "test.txt", "filename": "test.txt", "file_content_type": "text/plain", "file_id": "testid1", diff --git a/tests/test_path_validation.py b/tests/test_path_validation.py new file mode 100644 index 00000000..cf0e399d --- /dev/null +++ b/tests/test_path_validation.py @@ -0,0 +1,411 @@ +""" +Tests for CVE-2025-68413 / CVE-2025-68414 path traversal fixes. + +Validates that validate_file_path() correctly prevents directory traversal, +symlink escape, and other path manipulation attacks on all four protected +endpoints: /local/embed, /embed, /embed-upload, /text. + +These tests are designed to catch regressions — if any test here fails, +the CVE fix is broken. 
+""" + +import os +import pytest +from pathlib import Path + +from app.routes.document_routes import validate_file_path + + +# --------------------------------------------------------------------------- +# Unit tests for validate_file_path() +# --------------------------------------------------------------------------- + + +class TestValidateFilePathTraversal: + """Ensure directory traversal attempts are rejected.""" + + @pytest.mark.parametrize( + "malicious_path", + [ + "../../etc/passwd", + "../../../etc/shadow", + "../sibling_dir/secret", + "subdir/../../../etc/passwd", + "subdir/./../../etc/passwd", + "valid/../../../etc/passwd", + ], + ids=[ + "dotdot-etc-passwd", + "triple-dotdot-etc-shadow", + "dotdot-sibling", + "subdir-then-escape", + "dot-slash-then-escape", + "valid-then-escape", + ], + ) + def test_traversal_attempts_rejected(self, tmp_path, malicious_path): + result = validate_file_path(str(tmp_path), malicious_path) + assert result is None, ( + f"Path traversal not blocked: validate_file_path({tmp_path!r}, {malicious_path!r}) " + f"returned {result!r} instead of None" + ) + + @pytest.mark.parametrize( + "absolute_path", + [ + "/etc/passwd", + "/etc/shadow", + "/tmp/evil", + "/root/.ssh/id_rsa", + ], + ids=[ + "abs-etc-passwd", + "abs-etc-shadow", + "abs-tmp-evil", + "abs-root-ssh", + ], + ) + def test_absolute_path_escape_rejected(self, tmp_path, absolute_path): + result = validate_file_path(str(tmp_path), absolute_path) + assert result is None, ( + f"Absolute path escape not blocked: validate_file_path({tmp_path!r}, {absolute_path!r}) " + f"returned {result!r} instead of None" + ) + + +class TestValidateFilePathPrefixBypass: + """ + Regression test for the startswith() prefix-matching vulnerability. + + If base_dir = "/app/uploads", a path like "/app/uploads_evil/file" + passes str.startswith("/app/uploads"). The fix must use path-boundary- + aware comparison (e.g. Path.relative_to or appending os.sep). 
+ """ + + def test_sibling_directory_with_shared_prefix(self, tmp_path): + """CVE-2025-68413 core regression: sibling dir with same prefix.""" + base_dir = tmp_path / "uploads" + evil_dir = tmp_path / "uploads_evil" + base_dir.mkdir() + evil_dir.mkdir() + evil_file = evil_dir / "stolen.txt" + evil_file.write_text("sensitive data") + + result = validate_file_path(str(base_dir), str(evil_file)) + assert result is None, ( + f"Prefix bypass not blocked: sibling dir 'uploads_evil' was accessible " + f"from base 'uploads'. Got {result!r}" + ) + + def test_sibling_directory_relative_prefix_bypass(self, tmp_path): + """Same prefix attack via relative path component.""" + base_dir = tmp_path / "uploads" + evil_dir = tmp_path / "uploads2" + base_dir.mkdir() + evil_dir.mkdir() + evil_file = evil_dir / "data.txt" + evil_file.write_text("secret") + + # Try relative path that might resolve to uploads2/ + result = validate_file_path(str(base_dir), "../uploads2/data.txt") + assert result is None + + +class TestValidateFilePathSymlinks: + """ + Regression test for symlink traversal (os.path.abspath vs realpath). + + If base_dir contains a symlink pointing outside it, abspath won't + detect the escape but realpath/resolve will. + """ + + def test_symlink_escape(self, tmp_path): + """Symlink inside base_dir pointing to /tmp (outside base).""" + base_dir = tmp_path / "uploads" + base_dir.mkdir() + target_dir = tmp_path / "outside" + target_dir.mkdir() + secret = target_dir / "secret.txt" + secret.write_text("private data") + + link = base_dir / "escape_link" + link.symlink_to(target_dir) + + result = validate_file_path(str(base_dir), "escape_link/secret.txt") + assert result is None, ( + f"Symlink escape not blocked: 'escape_link' → {target_dir} " + f"was traversable. 
Got {result!r}" + ) + + def test_symlink_to_parent(self, tmp_path): + """Symlink pointing to parent directory.""" + base_dir = tmp_path / "uploads" + base_dir.mkdir() + + link = base_dir / "parent_link" + link.symlink_to(tmp_path) + + result = validate_file_path(str(base_dir), "parent_link/uploads/../secret") + assert result is None + + +class TestValidateFilePathEdgeCases: + """Edge cases and malformed input.""" + + def test_empty_string_rejected(self, tmp_path): + result = validate_file_path(str(tmp_path), "") + assert result is None, "Empty string should be rejected" + + def test_whitespace_only_rejected(self, tmp_path): + result = validate_file_path(str(tmp_path), " ") + assert result is None, "Whitespace-only path should be rejected" + + def test_dot_only_returns_none_or_base(self, tmp_path): + """A single dot resolves to base_dir itself — should be rejected (not a file).""" + result = validate_file_path(str(tmp_path), ".") + # Accepting the base dir itself is a design choice; either None or + # the base dir string is acceptable, but it must NOT escape. 
+ if result is not None: + assert Path(result).resolve() == tmp_path.resolve() + + def test_null_byte_rejected(self, tmp_path): + """Null bytes in filenames should be rejected or cause no harm.""" + try: + result = validate_file_path(str(tmp_path), "file\x00.pdf") + # If it doesn't raise, it must return None + assert result is None, "Null byte in filename should be rejected" + except (ValueError, TypeError): + pass # Raising is also acceptable + + def test_very_long_path(self, tmp_path): + """Extremely long paths should not cause crashes.""" + long_name = "a" * 1000 + try: + result = validate_file_path(str(tmp_path), long_name) + # Either None or a valid path under base_dir + if result is not None: + assert result.startswith(str(tmp_path)) + except (OSError, ValueError): + pass # OS-level rejection is fine + + +class TestValidateFilePathValidInputs: + """Ensure legitimate paths are accepted correctly.""" + + def test_simple_filename(self, tmp_path): + result = validate_file_path(str(tmp_path), "document.pdf") + assert result is not None + assert Path(result).resolve().parent == tmp_path.resolve() + + def test_filename_with_spaces(self, tmp_path): + result = validate_file_path(str(tmp_path), "my document.pdf") + assert result is not None + assert str(tmp_path) in result + + def test_subdirectory_path(self, tmp_path): + subdir = tmp_path / "subdir" + subdir.mkdir() + result = validate_file_path(str(tmp_path), "subdir/file.txt") + assert result is not None + resolved = Path(result).resolve() + assert resolved.is_relative_to(tmp_path.resolve()) + + def test_returned_path_is_absolute(self, tmp_path): + result = validate_file_path(str(tmp_path), "test.txt") + assert result is not None + assert os.path.isabs(result), f"Expected absolute path, got {result!r}" + + +# --------------------------------------------------------------------------- +# Integration tests: endpoint-level path traversal via TestClient +# 
--------------------------------------------------------------------------- + +import datetime +import jwt +from fastapi.testclient import TestClient +from unittest.mock import patch, AsyncMock +from concurrent.futures import ThreadPoolExecutor + +from main import app + +client = TestClient(app) + + +@pytest.fixture +def auth_headers(): + jwt_secret = "testsecret" + os.environ["JWT_SECRET"] = jwt_secret + payload = { + "id": "testuser", + "exp": datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(hours=1), + } + token = jwt.encode(payload, jwt_secret, algorithm="HS256") + return {"Authorization": f"Bearer {token}"} + + +@pytest.fixture(autouse=True) +def _setup_thread_pool(): + """Ensure app.state.thread_pool exists for tests.""" + if not hasattr(app.state, "thread_pool") or app.state.thread_pool is None: + app.state.thread_pool = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="test-worker" + ) + + +class TestLocalEmbedPathTraversal: + """CVE-2025-68413: /local/embed path traversal via document.filepath.""" + + @pytest.mark.parametrize( + "filepath", + [ + "../../etc/passwd", + "/etc/passwd", + "../../../etc/shadow", + "subdir/../../../etc/passwd", + ], + ) + def test_traversal_rejected(self, auth_headers, filepath): + data = { + "filepath": filepath, + "filename": "evil.txt", + "file_content_type": "text/plain", + "file_id": "testid1", + } + response = client.post("/local/embed", json=data, headers=auth_headers) + # Should get 404 (file not found) or 400 (invalid) — NOT 200 + assert response.status_code in (400, 404), ( + f"Path traversal not blocked on /local/embed with filepath={filepath!r}. 
" + f"Got status {response.status_code}: {response.text}" + ) + + +class TestEntityIdPathTraversal: + """Path traversal via entity_id parameter poisoning temp_base_path.""" + + @pytest.mark.parametrize( + "entity_id", + [ + "../../etc", + "../../../", + "legit/../../../etc", + ], + ) + def test_embed_entity_id_traversal(self, auth_headers, entity_id, tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + with test_file.open("rb") as f: + response = client.post( + "/embed", + data={"file_id": "testid1", "entity_id": entity_id}, + files={"file": ("safe.txt", f, "text/plain")}, + headers=auth_headers, + ) + assert response.status_code == 400, ( + f"entity_id traversal not blocked on /embed with entity_id={entity_id!r}. " + f"Got status {response.status_code}: {response.text}" + ) + + @pytest.mark.parametrize( + "entity_id", + [ + "../../etc", + "../../../", + "legit/../../../etc", + ], + ) + def test_text_entity_id_traversal(self, auth_headers, entity_id, tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + with test_file.open("rb") as f: + response = client.post( + "/text", + data={"file_id": "testid1", "entity_id": entity_id}, + files={"file": ("safe.txt", f, "text/plain")}, + headers=auth_headers, + ) + assert response.status_code == 400, ( + f"entity_id traversal not blocked on /text with entity_id={entity_id!r}. 
" + f"Got status {response.status_code}: {response.text}" + ) + + +class TestEmbedPathTraversal: + """CVE-2025-68414: /embed path traversal via filename.""" + + @pytest.mark.parametrize( + "filename", + [ + "../../etc/passwd", + "../../../etc/shadow", + "/etc/passwd", + ], + ) + def test_traversal_rejected(self, auth_headers, filename, tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + with test_file.open("rb") as f: + response = client.post( + "/embed", + data={"file_id": "testid1", "entity_id": "testuser"}, + files={"file": (filename, f, "text/plain")}, + headers=auth_headers, + ) + assert response.status_code == 400, ( + f"Path traversal not blocked on /embed with filename={filename!r}. " + f"Got status {response.status_code}: {response.text}" + ) + + +class TestEmbedUploadPathTraversal: + """CVE-2025-68414: /embed-upload path traversal via filename.""" + + @pytest.mark.parametrize( + "filename", + [ + "../../etc/passwd", + "../../../etc/shadow", + "/etc/passwd", + ], + ) + def test_traversal_rejected(self, auth_headers, filename, tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + with test_file.open("rb") as f: + response = client.post( + "/embed-upload", + data={"file_id": "testid1", "entity_id": "testuser"}, + files={"uploaded_file": (filename, f, "text/plain")}, + headers=auth_headers, + ) + assert response.status_code == 400, ( + f"Path traversal not blocked on /embed-upload with filename={filename!r}. 
" + f"Got status {response.status_code}: {response.text}" + ) + + +class TestTextEndpointPathTraversal: + """CVE-2025-68414: /text path traversal via filename.""" + + @pytest.mark.parametrize( + "filename", + [ + "../../etc/passwd", + "../../../etc/shadow", + "/etc/passwd", + ], + ) + def test_traversal_rejected(self, auth_headers, filename, tmp_path): + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + with test_file.open("rb") as f: + response = client.post( + "/text", + data={"file_id": "testid1", "entity_id": "testuser"}, + files={"file": (filename, f, "text/plain")}, + headers=auth_headers, + ) + assert response.status_code == 400, ( + f"Path traversal not blocked on /text with filename={filename!r}. " + f"Got status {response.status_code}: {response.text}" + ) diff --git a/tests/test_upload_isolation.py b/tests/test_upload_isolation.py new file mode 100644 index 00000000..5b1770bb --- /dev/null +++ b/tests/test_upload_isolation.py @@ -0,0 +1,88 @@ +""" +Tests for upload temp file isolation and generate_digest correctness. 
+ +Validates: +- _make_unique_temp_path produces unique paths per call (no concurrent collisions) +- _make_unique_temp_path isolates users into separate subdirectories +- _make_unique_temp_path rejects path traversal filenames +- generate_digest is consistent for all string inputs including surrogates +""" + +import hashlib +import os +from pathlib import Path +import pytest + +from app.routes.document_routes import _make_unique_temp_path, generate_digest + + +class TestMakeUniqueTempPath: + """Ensure temp file paths are unique and user-isolated.""" + + def test_two_calls_same_filename_produce_different_paths( + self, monkeypatch, tmp_path + ): + monkeypatch.setattr("app.routes.document_routes.RAG_UPLOAD_DIR", str(tmp_path)) + path_a = _make_unique_temp_path("user1", "report.pdf") + path_b = _make_unique_temp_path("user1", "report.pdf") + assert path_a != path_b, "Same user+filename must produce unique paths" + + def test_different_users_produce_different_directories(self, monkeypatch, tmp_path): + monkeypatch.setattr("app.routes.document_routes.RAG_UPLOAD_DIR", str(tmp_path)) + path_a = _make_unique_temp_path("user1", "report.pdf") + path_b = _make_unique_temp_path("user2", "report.pdf") + assert os.path.dirname(path_a) != os.path.dirname(path_b) + assert Path(path_a).parent.name == "user1" + assert Path(path_b).parent.name == "user2" + + def test_preserves_file_extension(self, monkeypatch, tmp_path): + monkeypatch.setattr("app.routes.document_routes.RAG_UPLOAD_DIR", str(tmp_path)) + path = _make_unique_temp_path("user1", "data.csv") + assert path.endswith(".csv") + + def test_path_stays_within_upload_dir(self, monkeypatch, tmp_path): + monkeypatch.setattr("app.routes.document_routes.RAG_UPLOAD_DIR", str(tmp_path)) + path = _make_unique_temp_path("user1", "file.txt") + assert path.startswith(str(tmp_path)) + + @pytest.mark.parametrize( + "malicious_filename", + [ + "../../etc/passwd", + "../../../etc/shadow", + "/etc/passwd", + ], + ) + def 
test_rejects_path_traversal(self, monkeypatch, tmp_path, malicious_filename): + monkeypatch.setattr("app.routes.document_routes.RAG_UPLOAD_DIR", str(tmp_path)) + result = _make_unique_temp_path("user1", malicious_filename) + assert result is None + + +class TestGenerateDigest: + """Ensure generate_digest is correct for all inputs.""" + + def test_normal_string(self): + content = "hello world" + expected = hashlib.md5(content.encode("utf-8")).hexdigest() + assert generate_digest(content) == expected + + def test_empty_string(self): + expected = hashlib.md5(b"").hexdigest() + assert generate_digest("") == expected + + def test_unicode_content(self): + content = "café résumé naïve" + expected = hashlib.md5(content.encode("utf-8")).hexdigest() + assert generate_digest(content) == expected + + def test_surrogate_characters(self): + """Surrogate chars are stripped by encode('utf-8', 'ignore').""" + content = "hello\ud800world" + expected = hashlib.md5(content.encode("utf-8", "ignore")).hexdigest() + assert generate_digest(content) == expected + assert len(generate_digest(content)) == 32 + + def test_deterministic(self): + content = "same input" + assert generate_digest(content) == generate_digest(content) diff --git a/tests/utils/test_document_loader.py b/tests/utils/test_document_loader.py index 5276acc7..447af6f3 100644 --- a/tests/utils/test_document_loader.py +++ b/tests/utils/test_document_loader.py @@ -1,13 +1,18 @@ import os +from collections.abc import Iterator +from unittest.mock import MagicMock, patch + from app.utils.document_loader import get_loader, clean_text, process_documents from langchain_core.documents import Document + def test_clean_text(): text = "Hello\x00World" cleaned = clean_text(text) assert "\x00" not in cleaned assert cleaned == "HelloWorld" + def test_get_loader_text(tmp_path): # Create a temporary text file. file_path = tmp_path / "test.txt" @@ -19,16 +24,22 @@ def test_get_loader_text(tmp_path): # Check that data is loaded. 
assert data is not None + def test_process_documents(): docs = [ - Document(page_content="Page 1 content", metadata={"source": "dummy.txt", "page": 1}), - Document(page_content="Page 2 content", metadata={"source": "dummy.txt", "page": 2}), + Document( + page_content="Page 1 content", metadata={"source": "dummy.txt", "page": 1} + ), + Document( + page_content="Page 2 content", metadata={"source": "dummy.txt", "page": 2} + ), ] processed = process_documents(docs) assert "dummy.txt" in processed assert "# PAGE 1" in processed assert "# PAGE 2" in processed + def test_safe_pdf_loader_class(): """Test that SafePyPDFLoader class can be instantiated""" from app.utils.document_loader import SafePyPDFLoader @@ -39,16 +50,120 @@ def test_safe_pdf_loader_class(): assert loader.extract_images == True assert loader._temp_filepath is None + +def test_get_loader_text_lazy_load(tmp_path): + """Test that lazy_load returns an iterator yielding documents.""" + file_path = tmp_path / "test.txt" + file_path.write_text("Sample text") + loader, known_type, file_ext = get_loader("test.txt", "text/plain", str(file_path)) + assert known_type is True + assert file_ext == "txt" + data = list(loader.lazy_load()) + assert len(data) > 0 + assert hasattr(data[0], "page_content") + + def test_get_loader_pdf(tmp_path): """Test get_loader returns SafePyPDFLoader for PDF files""" # Create a dummy PDF file path (doesn't need to be real for this test) file_path = tmp_path / "test.pdf" file_path.write_text("dummy content") # Not a real PDF, but that's OK for this test - loader, known_type, file_ext = get_loader("test.pdf", "application/pdf", str(file_path)) + loader, known_type, file_ext = get_loader( + "test.pdf", "application/pdf", str(file_path) + ) # Check that we get our SafePyPDFLoader from app.utils.document_loader import SafePyPDFLoader + assert isinstance(loader, SafePyPDFLoader) assert known_type is True - assert file_ext == "pdf" \ No newline at end of file + assert file_ext == "pdf" + + 
+def test_safe_pdf_loader_lazy_load(): + """Test that SafePyPDFLoader.lazy_load() returns an Iterator.""" + from app.utils.document_loader import SafePyPDFLoader + + loader = SafePyPDFLoader("dummy.pdf", extract_images=False) + assert hasattr(loader, "lazy_load") + result = loader.lazy_load() + assert isinstance(result, Iterator) + + +def test_safe_pdf_loader_fallback_no_duplicate_pages(): + """Fallback after mid-stream KeyError must not duplicate already-yielded pages.""" + from app.utils.document_loader import SafePyPDFLoader + + fallback_docs = [Document(page_content=f"fallback page {i}") for i in range(5)] + + def primary_gen(): + yield Document(page_content="partial page 0") + yield Document(page_content="partial page 1") + raise KeyError("/Filter") + + def fallback_gen(): + yield from fallback_docs + + loader = SafePyPDFLoader("dummy.pdf", extract_images=True) + + with patch("app.utils.document_loader.PyPDFLoader") as MockPDF: + primary_instance = MagicMock() + primary_instance.lazy_load.side_effect = primary_gen + fallback_instance = MagicMock() + fallback_instance.lazy_load.side_effect = fallback_gen + MockPDF.side_effect = [primary_instance, fallback_instance] + + result = list(loader.lazy_load()) + + # Must be exactly the 5 fallback pages, NOT 2 partial + 5 fallback = 7 + assert len(result) == 5 + assert result[0].page_content == "fallback page 0" + assert result[-1].page_content == "fallback page 4" + + +def test_safe_pdf_loader_fallback_via_load(): + """load() delegates to lazy_load(), so fallback must also be correct via load().""" + from app.utils.document_loader import SafePyPDFLoader + + fallback_docs = [Document(page_content=f"fb {i}") for i in range(3)] + + def primary_gen(): + yield Document(page_content="partial 0") + raise KeyError("/Filter") + + def fallback_gen(): + yield from fallback_docs + + loader = SafePyPDFLoader("dummy.pdf", extract_images=True) + + with patch("app.utils.document_loader.PyPDFLoader") as MockPDF: + primary_instance = 
MagicMock() + primary_instance.lazy_load.side_effect = primary_gen + fallback_instance = MagicMock() + fallback_instance.lazy_load.side_effect = fallback_gen + MockPDF.side_effect = [primary_instance, fallback_instance] + + result = loader.load() + + assert len(result) == 3 + assert result[0].page_content == "fb 0" + + +def test_safe_pdf_loader_non_filter_error_propagates(): + """KeyError that isn't /Filter should propagate, not silently fallback.""" + from app.utils.document_loader import SafePyPDFLoader + import pytest + + def bad_gen(): + raise KeyError("SomeOtherKey") + + loader = SafePyPDFLoader("dummy.pdf", extract_images=True) + + with patch("app.utils.document_loader.PyPDFLoader") as MockPDF: + instance = MagicMock() + instance.lazy_load.side_effect = bad_gen + MockPDF.return_value = instance + + with pytest.raises(KeyError, match="SomeOtherKey"): + list(loader.lazy_load()) diff --git a/tests/utils/test_lazy_load.py b/tests/utils/test_lazy_load.py new file mode 100644 index 00000000..9506f95f --- /dev/null +++ b/tests/utils/test_lazy_load.py @@ -0,0 +1,617 @@ +""" +Tests for lazy_load() across all supported document loaders. + +Verifies that every loader returned by get_loader() supports lazy_load(), +returns a generator/iterator, and yields valid Document objects with content. +Also includes memory benchmarks comparing lazy_load() vs load() to document +which loaders benefit from streaming (CSV, PDF) vs which don't (Unstructured*). 
+""" + +import gc +import os +import shutil +import tracemalloc +import zipfile +from collections.abc import Iterator + +import pytest +from langchain_core.documents import Document + +from app.utils.document_loader import get_loader, SafePyPDFLoader + +# --------------------------------------------------------------------------- +# Environment checks — these deps aren't guaranteed in every CI runner +# --------------------------------------------------------------------------- + +_has_pandoc = shutil.which("pandoc") is not None +if not _has_pandoc: + try: + import pypandoc + + pypandoc.get_pandoc_path() + _has_pandoc = True # pypandoc_binary or similar bundles the binary + except (ImportError, OSError): + pass + +try: + import msoffcrypto # noqa: F401 + + _has_msoffcrypto = True +except ImportError: + _has_msoffcrypto = False + +_skip_no_pandoc = pytest.mark.skipif(not _has_pandoc, reason="pandoc not installed") +_skip_no_msoffcrypto = pytest.mark.skipif( + not _has_msoffcrypto, reason="msoffcrypto not installed" +) + + +# --------------------------------------------------------------------------- +# Fixture helpers for generating test files in each format +# --------------------------------------------------------------------------- + + +def _make_pdf(path, *, num_pages=1): + """Create a PDF with extractable text using pypdf. + + When *num_pages* > 1 each page gets unique text so tracemalloc can + observe the difference between holding 1 page vs N pages in memory. 
+ """ + from pypdf import PdfWriter + from pypdf.generic import ( + DecodedStreamObject, + DictionaryObject, + NameObject, + ) + + writer = PdfWriter() + for i in range(num_pages): + writer.add_blank_page(width=612, height=792) + page = writer.pages[i] + + font_dict = DictionaryObject() + font_dict[NameObject("/Type")] = NameObject("/Font") + font_dict[NameObject("/Subtype")] = NameObject("/Type1") + font_dict[NameObject("/BaseFont")] = NameObject("/Helvetica") + + resources = DictionaryObject() + font_resources = DictionaryObject() + font_resources[NameObject("/F1")] = font_dict + resources[NameObject("/Font")] = font_resources + page[NameObject("/Resources")] = resources + + # Pad each page with unique filler so memory differences are measurable + filler = f"PAGE {i} " + ("X" * 2000) + content = f"BT /F1 12 Tf 100 700 Td ({filler}) Tj ET".encode() + stream = DecodedStreamObject() + stream.set_data(content) + page[NameObject("/Contents")] = stream + + with open(path, "wb") as f: + writer.write(f) + + +def _make_docx(path): + """Create a minimal DOCX (Office Open XML zip) with text content.""" + content_types = ( + '' + '' + ' ' + ' ' + ' ' + "" + ) + rels = ( + '' + '' + ' ' + "" + ) + document = ( + '' + '' + " " + " Hello from lazy_load DOCX test" + " " + "" + ) + word_rels = ( + '' + '' + "" + ) + with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as z: + z.writestr("[Content_Types].xml", content_types) + z.writestr("_rels/.rels", rels) + z.writestr("word/document.xml", document) + z.writestr("word/_rels/document.xml.rels", word_rels) + + +def _make_xlsx(path): + """Create a minimal XLSX with openpyxl.""" + from openpyxl import Workbook + + wb = Workbook() + ws = wb.active + ws["A1"] = "Name" + ws["B1"] = "Value" + ws["A2"] = "Test Item" + ws["B2"] = 42 + wb.save(path) + + +def _make_pptx(path): + """Create a minimal PPTX with python-pptx.""" + from pptx import Presentation + + prs = Presentation() + slide = prs.slides.add_slide(prs.slide_layouts[1]) + 
slide.shapes.title.text = "Test Slide" + slide.placeholders[1].text = "Hello from lazy_load PPTX test" + prs.save(path) + + +def _make_epub(path): + """Create a minimal EPUB 3 file (zip with OEBPS structure).""" + container_xml = ( + '' + '' + " " + ' ' + " " + "" + ) + content_opf = ( + '' + '' + ' ' + ' test-epub-lazy' + " Test EPUB" + " en" + ' 2024-01-01T00:00:00Z' + " " + " " + ' ' + ' ' + " " + " " + ' ' + " " + "" + ) + chapter1 = ( + '' + "" + '' + "Chapter 1" + "
+        "<body>"
+        "  <h1>Chapter 1</h1>"
+        "  <p>Hello from lazy_load EPUB test</p>"
" + "" + ) + nav = ( + '' + "" + '' + "Nav" + "" + '" + "" + ) + with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as z: + z.writestr("mimetype", "application/epub+zip") + z.writestr("META-INF/container.xml", container_xml) + z.writestr("OEBPS/content.opf", content_opf) + z.writestr("OEBPS/chapter1.xhtml", chapter1) + z.writestr("OEBPS/nav.xhtml", nav) + + +def _write_text(path, content): + """Write text to a file with proper handle cleanup.""" + with open(path, "w", encoding="utf-8") as f: + f.write(content) + + +def _make_large_csv(path, num_rows=500): + """Create a CSV with many rows so memory differences are measurable.""" + with open(path, "w", encoding="utf-8") as f: + f.write("id,name,description\n") + for i in range(num_rows): + # ~200 bytes per row + f.write(f'{i},item_{i},"{"D" * 150} row {i}"\n') + + +# --------------------------------------------------------------------------- +# Parametrized test: lazy_load() for every loader +# --------------------------------------------------------------------------- + +# (filename, content_type, file_creator_or_text, expected_substring) +LOADER_CASES = [ + pytest.param( + "test.txt", + "text/plain", + "Hello from lazy_load TXT test", + "Hello from lazy_load TXT test", + id="txt", + ), + pytest.param( + "test.csv", + "text/csv", + "name,value\nAlpha,1\nBravo,2\n", + "Alpha", + id="csv", + ), + pytest.param( + "test.json", + "application/json", + '{"key": "Hello from lazy_load JSON test"}', + "Hello from lazy_load JSON test", + id="json", + ), + pytest.param( + "test.md", + "text/markdown", + "# Heading\n\nHello from lazy_load MD test\n", + "Hello from lazy_load MD test", + id="md", + ), + pytest.param( + "test.rst", + "text/x-rst", + "Heading\n=======\n\nHello from lazy_load RST test\n", + "Hello from lazy_load RST test", + id="rst", + marks=_skip_no_pandoc, + ), + pytest.param( + "test.xml", + "application/xml", + 'Hello from lazy_load XML test', + "Hello from lazy_load XML test", + id="xml", + ), + 
pytest.param( + "test.py", + "text/x-python", + '# Hello from lazy_load PY test\nprint("hello")\n', + "Hello from lazy_load PY test", + id="py-source", + ), + pytest.param("test.pdf", "application/pdf", _make_pdf, "PAGE 0", id="pdf"), + pytest.param( + "test.docx", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + _make_docx, + "lazy_load DOCX test", + id="docx", + ), + pytest.param( + "test.xlsx", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + _make_xlsx, + "Test Item", + id="xlsx", + marks=_skip_no_msoffcrypto, + ), + pytest.param( + "test.pptx", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + _make_pptx, + "lazy_load PPTX test", + id="pptx", + ), + pytest.param( + "test.epub", + "application/epub+zip", + _make_epub, + "lazy_load EPUB test", + id="epub", + marks=_skip_no_pandoc, + ), +] + + +@pytest.mark.parametrize("filename,content_type,creator,expected_text", LOADER_CASES) +def test_lazy_load_returns_documents( + tmp_path, filename, content_type, creator, expected_text +): + """Every supported loader's lazy_load() should yield Document objects with content.""" + file_path = tmp_path / filename + + # Create the test file -- either write text or call a builder function + if callable(creator): + creator(str(file_path)) + else: + file_path.write_text(creator, encoding="utf-8") + + loader, known_type, file_ext = get_loader(filename, content_type, str(file_path)) + + # Verify the loader has lazy_load + assert hasattr(loader, "lazy_load"), f"{type(loader).__name__} missing lazy_load()" + + # Consume the lazy iterator + docs = list(loader.lazy_load()) + + # Basic assertions + assert len(docs) > 0, f"{type(loader).__name__} yielded 0 documents" + assert all( + isinstance(d, Document) for d in docs + ), "lazy_load() must yield Document instances" + + # Content assertion -- at least one doc should contain the expected text + all_text = " ".join(d.page_content for d in docs) + assert ( 
+ expected_text in all_text + ), f"{type(loader).__name__}: expected '{expected_text}' in output, got: {all_text[:200]}" + + assert known_type is True + + +@pytest.mark.parametrize("filename,content_type,creator,expected_text", LOADER_CASES) +def test_lazy_load_matches_load( + tmp_path, filename, content_type, creator, expected_text +): + """lazy_load() consumed as a list should produce the same documents as load().""" + file_path = tmp_path / filename + + if callable(creator): + creator(str(file_path)) + else: + file_path.write_text(creator, encoding="utf-8") + + loader, _, _ = get_loader(filename, content_type, str(file_path)) + + eager_docs = loader.load() + # Re-create loader since some loaders are single-use or have internal state + loader2, _, _ = get_loader(filename, content_type, str(file_path)) + lazy_docs = list(loader2.lazy_load()) + + assert len(eager_docs) == len(lazy_docs), ( + f"{type(loader).__name__}: load() returned {len(eager_docs)} docs, " + f"lazy_load() returned {len(lazy_docs)}" + ) + + for i, (eager, lazy) in enumerate(zip(eager_docs, lazy_docs)): + assert ( + eager.page_content == lazy.page_content + ), f"{type(loader).__name__} doc[{i}]: content mismatch between load() and lazy_load()" + + +# --------------------------------------------------------------------------- +# SafePyPDFLoader-specific tests +# --------------------------------------------------------------------------- + + +def test_safe_pdf_loader_lazy_load_is_iterator(tmp_path): + """SafePyPDFLoader.lazy_load() should return an Iterator.""" + pdf_path = tmp_path / "gen_test.pdf" + _make_pdf(str(pdf_path)) + + loader = SafePyPDFLoader(str(pdf_path), extract_images=False) + result = loader.lazy_load() + assert isinstance(result, Iterator) + + # Consuming the iterator should yield documents + docs = list(result) + assert len(docs) > 0 + + +def test_safe_pdf_loader_load_delegates_to_lazy_load(tmp_path): + """SafePyPDFLoader.load() should produce the same results as 
list(lazy_load()).""" + pdf_path = tmp_path / "delegate_test.pdf" + _make_pdf(str(pdf_path)) + + loader1 = SafePyPDFLoader(str(pdf_path), extract_images=False) + loader2 = SafePyPDFLoader(str(pdf_path), extract_images=False) + + load_docs = loader1.load() + lazy_docs = list(loader2.lazy_load()) + + assert len(load_docs) == len(lazy_docs) + for ld, lz in zip(load_docs, lazy_docs): + assert ld.page_content == lz.page_content + + +# --------------------------------------------------------------------------- +# CSV with non-UTF-8 encoding +# --------------------------------------------------------------------------- + + +def test_lazy_load_csv_non_utf8(tmp_path): + """CSV files with non-UTF-8 encoding should still work via lazy_load().""" + csv_path = tmp_path / "latin1.csv" + csv_path.write_bytes("name,city\nJos\xe9,S\xe3o Paulo\n".encode("latin-1")) + + loader, known_type, _ = get_loader("latin1.csv", "text/csv", str(csv_path)) + docs = list(loader.lazy_load()) + + assert len(docs) > 0 + all_text = " ".join(d.page_content for d in docs) + # The text should have been converted to UTF-8 by get_loader + assert "Jos" in all_text + + +# --------------------------------------------------------------------------- +# Memory benchmarks: lazy_load() vs load() +# +# These use tracemalloc to measure peak memory. The streaming benchmarks +# (which discard Document objects as they iterate) represent the theoretical +# best-case for a future streaming pipeline. Current call sites materialize +# via list(), so they see no production benefit yet. +# +# Note: marked slow so they can be excluded from fast CI runs if needed. 
# ---------------------------------------------------------------------------


def _measure_load(loader_factory):
    """Run load() and return (docs, peak_memory_bytes).

    Tracing is stopped in a ``finally`` block so that a loader failure
    cannot leave tracemalloc running and poison later measurements.
    """
    gc.collect()
    # Defensive: a previously interrupted measurement may have left tracing on.
    if tracemalloc.is_tracing():
        tracemalloc.stop()
    tracemalloc.start()
    try:
        loader = loader_factory()
        docs = loader.load()
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    gc.collect()
    return docs, peak


def _measure_lazy_load_streaming(loader_factory):
    """Iterate lazy_load() and accumulate only the text, discarding Document
    objects as we go. This simulates a true streaming consumer and represents
    the theoretical best-case for lazy_load().

    Returns (texts, peak_memory_bytes). As in ``_measure_load``, tracing is
    stopped even if iteration raises.
    """
    gc.collect()
    if tracemalloc.is_tracing():
        tracemalloc.stop()
    tracemalloc.start()
    try:
        loader = loader_factory()
        texts = []
        for doc in loader.lazy_load():
            texts.append(doc.page_content)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    gc.collect()
    return texts, peak


class TestMemoryBenchmarkPDF:
    """Memory comparison for PyPDFLoader -- the loader most likely to benefit
    from lazy_load() since it yields page-by-page."""

    NUM_PAGES = 50

    @pytest.fixture()
    def pdf_path(self, tmp_path):
        # Generate a multi-page PDF once per test.
        path = tmp_path / "bench.pdf"
        _make_pdf(str(path), num_pages=self.NUM_PAGES)
        return str(path)

    def test_pdf_streaming_lazy_load_peak_memory(self, pdf_path):
        """Streaming lazy_load() should use <= peak memory vs load()."""

        def factory():
            return SafePyPDFLoader(pdf_path, extract_images=False)

        load_docs, peak_load = _measure_load(factory)
        texts, peak_lazy_stream = _measure_lazy_load_streaming(factory)

        assert len(load_docs) == self.NUM_PAGES
        assert len(texts) == self.NUM_PAGES

        # Streaming should not use MORE memory than eager (allow 10% noise)
        assert peak_lazy_stream <= peak_load * 1.10, (
            f"streaming peak ({peak_lazy_stream:,} B) exceeded "
            f"load() peak ({peak_load:,} B) by >10%"
        )


class TestMemoryBenchmarkCSV:
    """Memory comparison for CSVLoader -- yields one Document per row."""

    NUM_ROWS = 500

    @pytest.fixture()
    def csv_path(self, tmp_path):
        # Generate a many-row CSV once per test.
        path = tmp_path / "bench.csv"
        _make_large_csv(str(path), num_rows=self.NUM_ROWS)
        return str(path)

    def test_csv_streaming_lazy_load_peak_memory(self, csv_path):
        """Streaming lazy_load() should use significantly less peak memory
        than load() for CSVs with many rows."""

        def factory():
            # Local import: keep the heavy dependency out of module import time.
            from langchain_community.document_loaders import CSVLoader

            return CSVLoader(csv_path)

        load_docs, peak_load = _measure_load(factory)
        texts, peak_lazy_stream = _measure_lazy_load_streaming(factory)

        assert len(load_docs) == self.NUM_ROWS
        assert len(texts) == self.NUM_ROWS

        # CSV streaming should use meaningfully less memory
        assert peak_lazy_stream <= peak_load * 1.10, (
            f"streaming peak ({peak_lazy_stream:,} B) exceeded "
            f"load() peak ({peak_load:,} B) by >10%"
        )


UNSTRUCTURED_CASES = [
    pytest.param("test.md", "text/markdown", id="md"),
    pytest.param("test.xml", "application/xml", id="xml"),
    pytest.param("test.rst", "text/x-rst", id="rst", marks=_skip_no_pandoc),
    pytest.param(
        "test.epub",
        "application/epub+zip",
        id="epub",
        marks=_skip_no_pandoc,
    ),
    pytest.param(
        "test.xlsx",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        id="xlsx",
        marks=_skip_no_msoffcrypto,
    ),
    pytest.param(
        "test.pptx",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        id="pptx",
    ),
    pytest.param(
        "test.docx",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        id="docx",
    ),
]

# Map extensions to their file creators
_UNSTRUCTURED_CREATORS = {
    ".md": lambda p: _write_text(p, "# Heading\n\n" + ("word " * 500) + "\n"),
    # NOTE(review): the XML element tags were lost/garbled in the source text;
    # reconstructed as a minimal well-formed document -- confirm against the
    # original intent.
    ".xml": lambda p: _write_text(
        p,
        "<root>"
        + "".join(f"<item>item {i}</item>" for i in range(100))
        + "</root>",
    ),
    ".rst": lambda p: _write_text(
        p,
        "Title\n=====\n\n"
        + "\n\n".join(f"Paragraph {i}. " + "text " * 50 for i in range(20)),
    ),
    ".epub": lambda p: _make_epub(p),
    ".xlsx": lambda p: _make_xlsx(p),
    ".pptx": lambda p: _make_pptx(p),
    ".docx": lambda p: _make_docx(p),
}


@pytest.mark.parametrize("filename,content_type", UNSTRUCTURED_CASES)
def test_unstructured_lazy_load_no_memory_benefit(tmp_path, filename, content_type):
    """Unstructured-based loaders internally load the full file regardless of
    lazy_load() vs load(). Verify lazy_load() doc count matches load()."""
    file_path = tmp_path / filename
    ext = os.path.splitext(filename)[1]
    _UNSTRUCTURED_CREATORS[ext](str(file_path))

    def factory():
        loader, _, _ = get_loader(filename, content_type, str(file_path))
        return loader

    # Peak-memory values are intentionally discarded: only doc-count parity
    # is asserted here.
    load_docs, _ = _measure_load(factory)
    texts, _ = _measure_lazy_load_streaming(factory)

    assert len(load_docs) == len(texts)