diff --git a/docs/openapi.json b/docs/openapi.json index c32d109b6..00c01a4e8 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -348,6 +348,58 @@ } } }, + "/v1/rags/{rag_id}": { + "get": { + "tags": [ + "rags" + ], + "summary": "Get Rag Endpoint Handler", + "description": "Retrieve a single RAG by its unique ID.\n\nRaises:\n HTTPException:\n - 404 if RAG with the given ID is not found,\n - 500 if unable to connect to Llama Stack,\n - 500 for any unexpected retrieval errors.\n\nReturns:\n RAGInfoResponse: A single RAG's details", + "operationId": "get_rag_endpoint_handler_v1_rags__rag_id__get", + "parameters": [ + { + "name": "rag_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Rag Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RAGInfoResponse" + } + } + } + }, + "404": { + "response": "RAG with given id not found", + "description": "Not Found" + }, + "500": { + "response": "Unable to retrieve list of RAGs", + "cause": "Connection to Llama Stack is broken", + "description": "Internal Server Error" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/query": { "post": { "tags": [ @@ -3918,6 +3970,108 @@ "title": "RAGChunk", "description": "Model representing a RAG chunk used in the response."
}, + "RAGInfoResponse": { + "properties": { + "id": { + "type": "string", + "title": "Id", + "description": "Vector DB unique ID", + "examples": [ + "vs_00000000_0000_0000" + ] + }, + "name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Name", + "description": "Human readable vector DB name", + "examples": [ + "Faiss Store with Knowledge base" + ] + }, + "created_at": { + "type": "integer", + "title": "Created At", + "description": "When the vector store was created, represented as Unix time", + "examples": [ + 1763391371 + ] + }, + "last_active_at": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Last Active At", + "description": "When the vector store was last active, represented as Unix time", + "examples": [ + 1763391371 + ] + }, + "usage_bytes": { + "type": "integer", + "title": "Usage Bytes", + "description": "Storage byte(s) used by this vector DB", + "examples": [ + 0 + ] + }, + "expires_at": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Expires At", + "description": "When the vector store expires, represented as Unix time", + "examples": [ + 1763391371 + ] + }, + "object": { + "type": "string", + "title": "Object", + "description": "Object type", + "examples": [ + "vector_store" + ] + }, + "status": { + "type": "string", + "title": "Status", + "description": "Vector DB status", + "examples": [ + "completed" + ] + } + }, + "type": "object", + "required": [ + "id", + "name", + "created_at", + "last_active_at", + "usage_bytes", + "expires_at", + "object", + "status" + ], + "title": "RAGInfoResponse", + "description": "Model representing a response with information about RAG DB." 
+ }, "RAGListResponse": { "properties": { "rags": { diff --git a/docs/openapi.md b/docs/openapi.md index fd18010b2..791a335e8 100644 --- a/docs/openapi.md +++ b/docs/openapi.md @@ -220,6 +220,38 @@ Returns: |-------------|-------------|-----------| | 200 | Successful Response | [RAGListResponse](#raglistresponse) | | 500 | Connection to Llama Stack is broken | | +## GET `/v1/rags/{rag_id}` + +> **Get Rag Endpoint Handler** + +Retrieve a single RAG by its unique ID. + +Raises: + HTTPException: + - 404 if RAG with the given ID is not found, + - 500 if unable to connect to Llama Stack, + - 500 for any unexpected retrieval errors. + +Returns: + RAGInfoResponse: A single RAG's details + + + +### 🔗 Parameters + +| Name | Type | Required | Description | +|------|------|----------|-------------| +| rag_id | string | True | | + + +### ✅ Responses + +| Status Code | Description | Component | +|-------------|-------------|-----------| +| 200 | Successful Response | [RAGInfoResponse](#raginforesponse) | +| 404 | Not Found | | +| 500 | Internal Server Error | | +| 422 | Validation Error | [HTTPValidationError](#httpvalidationerror) | ## POST `/v1/query` > **Query Endpoint Handler** @@ -1688,6 +1720,24 @@ Model representing a RAG chunk used in the response. | score | | Relevance score | +## RAGInfoResponse + + +Model representing a response with information about RAG DB.
+ + +| Field | Type | Description | +|-------|------|-------------| +| id | string | Vector DB unique ID | +| name | | Human readable vector DB name | +| created_at | integer | When the vector store was created, represented as Unix time | +| last_active_at | | When the vector store was last active, represented as Unix time | +| usage_bytes | integer | Storage byte(s) used by this vector DB | +| expires_at | | When the vector store expires, represented as Unix time | +| object | string | Object type | +| status | string | Vector DB status | + + ## RAGListResponse diff --git a/docs/output.md b/docs/output.md index fd18010b2..791a335e8 100644 --- a/docs/output.md +++ b/docs/output.md @@ -220,6 +220,38 @@ Returns: |-------------|-------------|-----------| | 200 | Successful Response | [RAGListResponse](#raglistresponse) | | 500 | Connection to Llama Stack is broken | | +## GET `/v1/rags/{rag_id}` + +> **Get Rag Endpoint Handler** + +Retrieve a single RAG by its unique ID. + +Raises: + HTTPException: + - 404 if RAG with the given ID is not found, + - 500 if unable to connect to Llama Stack, + - 500 for any unexpected retrieval errors. + +Returns: + RAGInfoResponse: A single RAG's details + + + +### 🔗 Parameters + +| Name | Type | Required | Description | +|------|------|----------|-------------| +| rag_id | string | True | | + + +### ✅ Responses + +| Status Code | Description | Component | +|-------------|-------------|-----------| +| 200 | Successful Response | [RAGInfoResponse](#raginforesponse) | +| 404 | Not Found | | +| 500 | Internal Server Error | | +| 422 | Validation Error | [HTTPValidationError](#httpvalidationerror) | ## POST `/v1/query` > **Query Endpoint Handler** @@ -1688,6 +1720,24 @@ Model representing a RAG chunk used in the response. | score | | Relevance score | +## RAGInfoResponse + + +Model representing a response with information about RAG DB.
+ + +| Field | Type | Description | +|-------|------|-------------| +| id | string | Vector DB unique ID | +| name | | Human readable vector DB name | +| created_at | integer | When the vector store was created, represented as Unix time | +| last_active_at | | When the vector store was last active, represented as Unix time | +| usage_bytes | integer | Storage byte(s) used by this vector DB | +| expires_at | | When the vector store expires, represented as Unix time | +| object | string | Object type | +| status | string | Vector DB status | + + ## RAGListResponse diff --git a/src/app/endpoints/rags.py b/src/app/endpoints/rags.py index e09fdfe91..d1c060dba 100644 --- a/src/app/endpoints/rags.py +++ b/src/app/endpoints/rags.py @@ -13,7 +13,7 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration from models.config import Action -from models.responses import RAGListResponse +from models.responses import RAGListResponse, RAGInfoResponse from utils.endpoints import check_configuration_loaded logger = logging.getLogger(__name__) @@ -31,6 +31,15 @@ 500: {"description": "Connection to Llama Stack is broken"}, } +rag_responses: dict[int | str, dict[str, Any]] = { + 200: {}, + 404: {"response": "RAG with given id not found"}, + 500: { + "response": "Unable to retrieve list of RAGs", + "cause": "Connection to Llama Stack is broken", + }, +} + @router.get("/rags", responses=rags_responses) @authorize(Action.LIST_RAGS) @@ -94,3 +103,72 @@ async def rags_endpoint_handler( "cause": str(e), }, ) from e + + +@router.get("/rags/{rag_id}", responses=rag_responses) +@authorize(Action.GET_RAG) +async def get_rag_endpoint_handler( + request: Request, + rag_id: str, + auth: Annotated[AuthTuple, Depends(get_auth_dependency())], +) -> RAGInfoResponse: + """Retrieve a single RAG by its unique ID. 
+ + Raises: + HTTPException: + - 404 if RAG with the given ID is not found, + - 500 if unable to connect to Llama Stack, + - 500 for any unexpected retrieval errors. + + Returns: + RAGInfoResponse: A single RAG's details + """ + # Used only by the middleware + _ = auth + + # Nothing interesting in the request + _ = request + + check_configuration_loaded(configuration) + + llama_stack_configuration = configuration.llama_stack_configuration + logger.info("Llama stack config: %s", llama_stack_configuration) + + try: + # try to get Llama Stack client + client = AsyncLlamaStackClientHolder().get_client() + # retrieve info about RAG + rag_info = await client.vector_stores.retrieve(rag_id) + return RAGInfoResponse( + id=rag_info.id, + name=rag_info.name, + created_at=rag_info.created_at, + last_active_at=rag_info.last_active_at, + expires_at=rag_info.expires_at, + object=rag_info.object, + status=rag_info.status, + usage_bytes=rag_info.usage_bytes, + ) + + # connection to Llama Stack server + except HTTPException: + raise + except APIConnectionError as e: + logger.error("Unable to connect to Llama Stack: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to connect to Llama Stack", + "cause": str(e), + }, + ) from e + # any other exception that can occur during model listing + except Exception as e: + logger.error("Unable to retrieve info about RAG: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "response": "Unable to retrieve info about RAG", + "cause": str(e), + }, + ) from e diff --git a/src/models/responses.py b/src/models/responses.py index 06ef87595..e05fc114f 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -87,6 +87,49 @@ class ShieldsResponse(BaseModel): ) +class RAGInfoResponse(BaseModel): + """Model representing a response with information about RAG DB.""" + + id: str = Field( + ..., description="Vector DB unique ID", 
examples=["vs_00000000_0000_0000"] + ) + name: Optional[str] = Field( + None, + description="Human readable vector DB name", + examples=["Faiss Store with Knowledge base"], + ) + created_at: int = Field( + ..., + description="When the vector store was created, represented as Unix time", + examples=[1763391371], + ) + last_active_at: Optional[int] = Field( + None, + description="When the vector store was last active, represented as Unix time", + examples=[1763391371], + ) + usage_bytes: int = Field( + ..., + description="Storage byte(s) used by this vector DB", + examples=[0], + ) + expires_at: Optional[int] = Field( + None, + description="When the vector store expires, represented as Unix time", + examples=[1763391371], + ) + object: str = Field( + ..., + description="Object type", + examples=["vector_store"], + ) + status: str = Field( + ..., + description="Vector DB status", + examples=["completed"], + ) + + class RAGListResponse(BaseModel): """Model representing a response to list RAGs request.""" diff --git a/tests/unit/app/endpoints/test_rags.py b/tests/unit/app/endpoints/test_rags.py index 4beebf377..d7f0766c0 100644 --- a/tests/unit/app/endpoints/test_rags.py +++ b/tests/unit/app/endpoints/test_rags.py @@ -8,6 +8,7 @@ from authentication.interface import AuthTuple from app.endpoints.rags import ( + get_rag_endpoint_handler, rags_endpoint_handler, ) @@ -108,3 +109,131 @@ def __init__(self) -> None: response = await rags_endpoint_handler(request=request, auth=auth) assert len(response.rags) == 3 + + +@pytest.mark.asyncio +async def test_rag_info_endpoint_configuration_not_loaded( + mocker: MockerFixture, +) -> None: + """Test that /rags/{rag_id} endpoint raises HTTP 500 if configuration is not loaded.""" + mocker.patch("app.endpoints.rags.configuration", None) + request = Request(scope={"type": "http"}) + + # Authorization tuple required by URL endpoint handler + auth: AuthTuple = ("test_user_id", "test_user", True, "test_token") + + with 
pytest.raises(HTTPException) as e: + await get_rag_endpoint_handler(request=request, auth=auth, rag_id="xyzzy") + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_rag_info_endpoint_rag_not_found(mocker: MockerFixture) -> None: + """Test that /rags/{rag_id} endpoint returns HTTP 404 when the requested RAG is not found.""" + mock_client = mocker.AsyncMock() + mock_client.vector_stores.retrieve.side_effect = HTTPException( + status_code=status.HTTP_404_NOT_FOUND + ) # type: ignore + mocker.patch( + "app.endpoints.rags.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = mock_client + + request = Request(scope={"type": "http"}) + + # Authorization tuple required by URL endpoint handler + auth: AuthTuple = ("test_user_id", "test_user", True, "test_token") + + with pytest.raises(HTTPException) as e: + await get_rag_endpoint_handler(request=request, auth=auth, rag_id="xyzzy") + assert e.value.status_code == status.HTTP_404_NOT_FOUND + + +@pytest.mark.asyncio +async def test_rag_info_endpoint_connection_error(mocker: MockerFixture) -> None: + """Test that /rags/{rag_id} endpoint raises HTTP 500 if Llama Stack connection fails.""" + mock_client = mocker.AsyncMock() + mock_client.vector_stores.retrieve.side_effect = APIConnectionError( + request=None # type: ignore + ) + mocker.patch( + "app.endpoints.rags.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = mock_client + + request = Request(scope={"type": "http"}) + + # Authorization tuple required by URL endpoint handler + auth: AuthTuple = ("test_user_id", "test_user", True, "test_token") + + with pytest.raises(HTTPException) as e: + await get_rag_endpoint_handler(request=request, auth=auth, rag_id="xyzzy") + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + detail = e.value.detail + assert isinstance(detail, dict) + assert "response" in detail + assert "Unable to connect to Llama Stack" in 
detail["response"] + + +@pytest.mark.asyncio +async def test_rag_info_endpoint_unable_to_retrieve_list(mocker: MockerFixture) -> None: + """Test that /rags/{rag_id} endpoint raises HTTP 500 if Llama Stack connection fails.""" + mock_client = mocker.AsyncMock() + mock_client.vector_stores.retrieve.side_effect = [] # type: ignore + mocker.patch( + "app.endpoints.rags.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = mock_client + + request = Request(scope={"type": "http"}) + + # Authorization tuple required by URL endpoint handler + auth: AuthTuple = ("test_user_id", "test_user", True, "test_token") + + with pytest.raises(HTTPException) as e: + await get_rag_endpoint_handler(request=request, auth=auth, rag_id="xyzzy") + assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + detail = e.value.detail + assert isinstance(detail, dict) + assert "response" in detail + assert "Unable to retrieve info about RAG" in detail["response"] + + +@pytest.mark.asyncio +async def test_rag_info_endpoint_success(mocker: MockerFixture) -> None: + """Test that /rags/{rag_id} endpoint returns information about selected RAG.""" + + # pylint: disable=R0902 + # pylint: disable=R0903 + class RagInfo: + """RagInfo mock.""" + + def __init__(self) -> None: + self.id = "xyzzy" + self.name = "rag_name" + self.created_at = 123456 + self.last_active_at = 1234567 + self.expires_at = 12345678 + self.object = "faiss" + self.status = "completed" + self.usage_bytes = 100 + + mock_client = mocker.AsyncMock() + mock_client.vector_stores.retrieve.return_value = RagInfo() + mocker.patch( + "app.endpoints.rags.AsyncLlamaStackClientHolder" + ).return_value.get_client.return_value = mock_client + + request = Request(scope={"type": "http"}) + + # Authorization tuple required by URL endpoint handler + auth: AuthTuple = ("test_user_id", "test_user", True, "test_token") + + response = await get_rag_endpoint_handler( + request=request, auth=auth, rag_id="xyzzy" + ) + assert 
response.id == "xyzzy" + assert response.name == "rag_name" + assert response.created_at == 123456 + assert response.last_active_at == 1234567 + assert response.expires_at == 12345678 + assert response.object == "faiss" + assert response.status == "completed" + assert response.usage_bytes == 100