Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import logging
import os
import warnings
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast
from typing import (
Any,
AsyncGenerator,
Callable,
Iterable,
Literal,
Mapping,
Optional,
Type,
TypeVar,
Union,
ValuesView,
cast,
)

import boto3
from botocore.config import Config as BotocoreConfig
Expand Down Expand Up @@ -493,23 +506,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationsContentBlock.html
if "citationsContent" in content:
citations = content["citationsContent"]
result = {}
citations_result: dict[str, Any] = {}

if "citations" in citations:
result["citations"] = []
citations_result["citations"] = []
for citation in citations["citations"]:
filtered_citation: dict[str, Any] = {}
if "location" in citation:
location = citation["location"]
filtered_location = {}
# Filter location fields to only include Bedrock-supported ones
if "documentIndex" in location:
filtered_location["documentIndex"] = location["documentIndex"]
if "start" in location:
filtered_location["start"] = location["start"]
if "end" in location:
filtered_location["end"] = location["end"]
filtered_citation["location"] = filtered_location
filtered_location = self._format_citation_location(citation["location"])
if filtered_location:
filtered_citation["location"] = filtered_location
if "sourceContent" in citation:
filtered_source_content: list[dict[str, Any]] = []
for source_content in citation["sourceContent"]:
Expand All @@ -519,20 +525,54 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
filtered_citation["sourceContent"] = filtered_source_content
if "title" in citation:
filtered_citation["title"] = citation["title"]
result["citations"].append(filtered_citation)
citations_result["citations"].append(filtered_citation)

if "content" in citations:
filtered_content: list[dict[str, Any]] = []
for generated_content in citations["content"]:
if "text" in generated_content:
filtered_content.append({"text": generated_content["text"]})
if filtered_content:
result["content"] = filtered_content
citations_result["content"] = filtered_content

return {"citationsContent": result}
return {"citationsContent": citations_result}

raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

def _format_citation_location(self, location: Mapping[str, Any]) -> dict[str, Any]:
"""Format a citation location preserving the tagged union structure.

The Bedrock API requires CitationLocation to be a tagged union with exactly one
of the following keys: web, documentChar, documentPage, documentChunk, or
searchResultLocation.

Args:
location: Citation location to format.

Returns:
Formatted location with tagged union structure preserved, or empty dict if invalid.

See:
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html
"""
# Allowed fields for each tagged union type
allowed_fields = {
"web": ("url", "domain"),
"documentChar": ("documentIndex", "start", "end"),
"documentPage": ("documentIndex", "start", "end"),
"documentChunk": ("documentIndex", "start", "end"),
"searchResultLocation": ("searchResultIndex", "start", "end"),
}

for location_type, fields in allowed_fields.items():
if location_type in location:
inner = location[location_type]
filtered = {k: v for k, v in inner.items() if k in fields}
return {location_type: filtered} if filtered else {}

logger.debug("location_type=<unknown> | unrecognized citation location type, skipping")
return {}

def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
"""Check if guardrail data contains any blocked policies.

Expand Down
106 changes: 90 additions & 16 deletions src/strands/types/citations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Citation type definitions for the SDK.

These types are modeled after the Bedrock API.
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CitationLocation.html
"""

from typing import List, Union
Expand All @@ -18,11 +19,8 @@ class CitationsConfig(TypedDict):
enabled: bool


class DocumentCharLocation(TypedDict, total=False):
"""Specifies a character-level location within a document.

Provides precise positioning information for cited content using
start and end character indices.
class DocumentCharLocationInner(TypedDict, total=False):
"""Inner content for character-level location within a document.

Attributes:
documentIndex: The index of the document within the array of documents
Expand All @@ -38,11 +36,8 @@ class DocumentCharLocation(TypedDict, total=False):
end: int


class DocumentChunkLocation(TypedDict, total=False):
"""Specifies a chunk-level location within a document.

Provides positioning information for cited content using logical
document segments or chunks.
class DocumentChunkLocationInner(TypedDict, total=False):
"""Inner content for chunk-level location within a document.

Attributes:
documentIndex: The index of the document within the array of documents
Expand All @@ -58,10 +53,8 @@ class DocumentChunkLocation(TypedDict, total=False):
end: int


class DocumentPageLocation(TypedDict, total=False):
"""Specifies a page-level location within a document.

Provides positioning information for cited content using page numbers.
class DocumentPageLocationInner(TypedDict, total=False):
"""Inner content for page-level location within a document.

Attributes:
documentIndex: The index of the document within the array of documents
Expand All @@ -77,8 +70,89 @@ class DocumentPageLocation(TypedDict, total=False):
end: int


# Union type for citation locations
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
class WebLocationInner(TypedDict, total=False):
"""Inner content for web-based location.

Attributes:
url: The URL of the web page containing the cited content.
domain: The domain of the web page containing the cited content.
"""

url: str
domain: str


class SearchResultLocationInner(TypedDict, total=False):
"""Inner content for search result location.

Attributes:
searchResultIndex: The index of the search result content block where
the cited content is found. Minimum value of 0.
start: The starting position in the content array where the cited
content begins. Minimum value of 0.
end: The ending position in the content array where the cited
content ends. Minimum value of 0.
"""

searchResultIndex: int
start: int
end: int


class DocumentCharLocation(TypedDict, total=False):
"""Tagged union wrapper for character-level document location.

Attributes:
documentChar: The character-level location data.
"""

documentChar: DocumentCharLocationInner


class DocumentChunkLocation(TypedDict, total=False):
"""Tagged union wrapper for chunk-level document location.

Attributes:
documentChunk: The chunk-level location data.
"""

documentChunk: DocumentChunkLocationInner


class DocumentPageLocation(TypedDict, total=False):
"""Tagged union wrapper for page-level document location.

Attributes:
documentPage: The page-level location data.
"""

documentPage: DocumentPageLocationInner


class WebLocation(TypedDict, total=False):
"""Tagged union wrapper for web-based location.

Attributes:
web: The web location data.
"""

web: WebLocationInner


class SearchResultLocation(TypedDict, total=False):
"""Tagged union wrapper for search result location.

Attributes:
searchResultLocation: The search result location data.
"""

searchResultLocation: SearchResultLocationInner


# Union type for citation locations - tagged union where exactly one key is present
CitationLocation = Union[
DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation, WebLocation, SearchResultLocation
]


class CitationSourceContent(TypedDict, total=False):
Expand Down
Loading